Skip to content

Commit 81b0627

Browse files
Sync from upstream (#2820)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 9b05fba commit 81b0627

File tree

9 files changed

+34
-26
lines changed

9 files changed

+34
-26
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ endfunction()
154154
if(NOT MSVC)
155155
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wno-covered-switch-default -fvisibility=hidden")
156156
else()
157-
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX-")
157+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /WX- /wd4244 /wd4624 /wd4715 /wd4530")
158158
endif()
159159

160160
include_directories(".")

docs/index.rst

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,6 @@ Getting Started
2222
getting-started/tutorials/index
2323

2424

25-
Programming Guide
26-
-----------------
27-
28-
Check out the following documents to learn more about Triton and its comparison with other DSLs for Deep Neural Networks (DNNs):
29-
30-
- Chapter 1: :doc:`Introduction <programming-guide/chapter-1/introduction>`
31-
- Chapter 2: :doc:`Related Work <programming-guide/chapter-2/related-work>`
32-
- Chapter 3: :doc:`Debugging <programming-guide/chapter-3/debugging>`
33-
34-
.. toctree::
35-
:maxdepth: 1
36-
:caption: Programming Guide
37-
:hidden:
38-
39-
programming-guide/chapter-1/introduction
40-
programming-guide/chapter-2/related-work
41-
programming-guide/chapter-3/debugging
42-
43-
.. _Triton: https://github.com/triton-lang/triton
44-
4525
Python API
4626
----------
4727

@@ -73,3 +53,23 @@ Triton MLIR Dialects and Ops
7353
:hidden:
7454

7555
dialects/dialects
56+
57+
Going Further
58+
-------------
59+
60+
Check out the following documents to learn more about Triton and how it compares against other DSLs for DNNs:
61+
62+
- Chapter 1: :doc:`Introduction <programming-guide/chapter-1/introduction>`
63+
- Chapter 2: :doc:`Related Work <programming-guide/chapter-2/related-work>`
64+
- Chapter 3: :doc:`Debugging <programming-guide/chapter-3/debugging>`
65+
66+
.. toctree::
67+
:maxdepth: 1
68+
:caption: Programming Guide
69+
:hidden:
70+
71+
programming-guide/chapter-1/introduction
72+
programming-guide/chapter-2/related-work
73+
programming-guide/chapter-3/debugging
74+
75+
.. _Triton: https://github.com/triton-lang/triton

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
5050
// know about the op to split the block.
5151
void llAssert(Operation *op, Value condition, StringRef message,
5252
ConversionPatternRewriter &rewriter) const {
53+
5354
auto ctx = rewriter.getContext();
5455
auto loc = op->getLoc();
5556

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
344344
auto dstTy = op.getType();
345345
auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
346346
SmallVector<Value> outVals(numRegs);
347-
for (int i = 0; i < outVals.size(); i++) {
347+
for (int i = 0; i < numRegs; i++) {
348348
// Remove free masks from the register index
349349
// For example, if idx = 0b00111, and masks = 0b00100, then we get
350350
// 0b00011. It means that register 7 (0b111) has the same value as

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ class SwizzleShmemConvert : public OpRewritePattern<ConvertLayoutOp> {
157157
if (auto srcCvt = trans.getSrc().getDefiningOp<ConvertLayoutOp>()) {
158158
srcTy = srcCvt.getSrc().getType();
159159
}
160-
auto sharedLoadTy = cast<RankedTensorType>(cvtOp.getType());
160+
RankedTensorType sharedLoadTy = cvtOp.getType();
161161
auto cvtEncoding =
162162
dyn_cast<DotOperandEncodingAttr>(sharedLoadTy.getEncoding());
163163
if (!cvtEncoding)

python/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,7 +764,7 @@ def get_git_commit_hash(length=8):
764764

765765
def get_install_requires():
766766
install_requires = [
767-
"packaging", # used by third_party/intel/backend/compiler.py
767+
"packaging", # used by third_party/intel/backend/driver.py
768768
] # yapf: disable
769769
return install_requires
770770

python/test/regression/test_cast_matmul.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
66
TODO: float8 types
77
"""
8+
89
import warnings
910
import pytest
1011
import torch

python/triton/backends/compiler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from abc import ABCMeta, abstractmethod
88
from dataclasses import dataclass
9-
from typing import Dict, Union
9+
from typing import Dict, List, Tuple, Union
1010
from types import ModuleType
1111

1212
# Table that associates strings to AttrsDescriptor (sub)classes.
@@ -171,7 +171,7 @@ def from_dict(data):
171171
return attrs_descriptor
172172

173173
@classmethod
174-
def from_hints(cls, hints: list[tuple[int, int]]):
174+
def from_hints(cls, hints: List[Tuple[int, int]]):
175175
"""
176176
Create the class from a set of hints that are passed in.
177177

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
232232
mfmaInstrK = elemsPerInstr[kDimIdx];
233233
}
234234

235+
if (mfmaInstrNonK > shape[nonKDimIdx] || mfmaInstrK > shape[kDimIdx]) {
236+
// This pattern does not support cases tensor shape is smaller than
237+
// one instruction size, it will be processed by LinearLayout converter
238+
return Value();
239+
}
240+
235241
auto numReps = mfmaLayout.getRepForOperand(shape, kWidth, opIdx);
236242
auto numRepNonK = numReps[nonKDimIdx];
237243
auto numRepK = numReps[kDimIdx];

0 commit comments

Comments
 (0)