Commit 45aeede

Merge commit '915c1499789ea2257ab494da833cb78789c9f5af'
2 parents e4fa38e + 915c149 commit 45aeede

File tree

22 files changed: +392 −107 lines changed


.github/ISSUE_TEMPLATE/bug.yml

Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
+name: Report a bug
+description: Report triton failing to compile a kernel, or giving incorrect results
+labels: ["bug"]
+
+body:
+  - type: markdown
+    attributes:
+      value: |
+        #### Disclaimer
+        The core triton team is small and has very limited capacity. We may not have time to look into your report.
+        For the best results, please:
+        - Avoid submitting duplicates. Search through [the existing and past issues](https://github.com/triton-lang/triton/issues?q=is%3Aissue+sort%3Acreated-desc+) first to see if it's been reported previously.
+        - Check if the issue persists with a build from the latest source.
+        - Provide all relevant information in the initial report, to prevent unnecessary back and forth discussion.
+        - If you can, try to diagnose and/or fix the issue yourself. We welcome high quality contributions.
+  - type: textarea
+    attributes:
+      label: Describe the bug
+      description: |
+        Please provide a clear and concise description of what the bug is.
+
+        If relevant, add a [minimal complete example](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the bug. It is very important for the snippet to be as simple as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did, so include both the kernel and launching code as well as any relevant imports.
+
+        If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
+
+        Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
+      placeholder: |
+        A clear and concise description of what the bug is.
+
+        ```python
+        # Sample code to reproduce the problem
+        ```
+
+        ```
+        The error message you got, with the full traceback.
+        ```
+    validations:
+      required: true
+  - type: textarea
+    attributes:
+      label: Environment details
+      description: |
+        Please include any relevant context about how you're running the reproducer e.g. which version of triton, and what GPU you are using.
+      placeholder: |
+        Triton: ...
+        GPU: ...
+    validations:
+      required: true

.github/ISSUE_TEMPLATE/config.yml

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+blank_issues_enabled: true
+contact_links:
+  - name: Community help
+    url: https://discord.gg/gpumode
+    about: GPU-mode discord community has a triton channel which is a great resource for help writing/learning triton
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+name: Report a performance issue
+description: Report cases where triton is generating sub-optimal (but functionally correct) PTX/LLVM IR
+labels: ["performance"]
+
+body:
+  - type: markdown
+    attributes:
+      value: |
+        #### Disclaimer
+        The core triton team is small and has very limited capacity. We may not have time to look into your report.
+        For the best results, please:
+        - Avoid submitting duplicates. Search through [the existing and past issues](https://github.com/triton-lang/triton/issues?q=is%3Aissue+sort%3Acreated-desc+) first to see if it's been reported previously.
+        - Check if the issue persists with a build from the latest source.
+        - Provide all relevant information in the initial report, to prevent unnecessary back and forth discussion.
+        - If you can, try to diagnose and/or fix the issue yourself. We welcome high quality contributions.
+  - type: textarea
+    attributes:
+      label: Describe the issue
+      description: |
+        Please provide a clear and concise description of the issue.
+
+        Include a [minimal complete example](https://stackoverflow.com/help/minimal-reproducible-example) that reproduces the issue. It is very important for the snippet to be as simple as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did.
+
+        A reproducer could be a python program that runs a triton kernel and prints out the relevant suboptimal IR, or an IR file with an accompanying triton-opt command.
+
+        If the code is too long (hopefully, it isn't), feel free to put it in a public gist and link it in the issue: https://gist.github.com.
+      placeholder: |
+        A clear and concise description of the issue.
+
+        ```python
+        # Sample code to reproduce the problem
+        ```
+    validations:
+      required: true
+  - type: textarea
+    attributes:
+      label: Environment details
+      description: |
+        Please include any relevant context about how you're running the reproducer e.g. which version of triton, and what GPU you are using.
+      placeholder: |
+        Triton: ...
+        GPU: ...
+    validations:
+      required: true

.github/PULL_REQUEST_TEMPLATE.md

Lines changed: 3 additions & 0 deletions
@@ -1,11 +1,14 @@
+<!---
 The core Triton is a small number of people, and we receive many PRs (thank
 you!). To help us review your code more quickly, **if you are a new
 contributor (less than 3 PRs merged) we ask that you complete the following
 tasks and include the filled-out checklist in your PR description.**

 Complete the following tasks before sending your PR, and replace `[ ]` with
 `[x]` to indicate you have done them.
+-->

+# New contributor declaration
 - [ ] I am not making a trivial change, such as fixing a typo in a comment.

 - [ ] I have written a PR description following these

.github/workflows/integration-tests.yml

Lines changed: 0 additions & 4 deletions
@@ -141,10 +141,6 @@ jobs:
       - name: Check pre-commit
         run: |
           python3 -m pip install --upgrade pre-commit
-          # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed
-          python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true
-          # If first run of yapf worked and made changes reset the tree to the original state
-          git reset --hard
           python3 -m pre_commit run --all-files --verbose
       - name: Print diff of changes if pre-commit failed
         if: failure()

.github/workflows/integration-tests.yml.in

Lines changed: 0 additions & 4 deletions
@@ -155,10 +155,6 @@ jobs:
       - name: Check pre-commit
         run: |
           python3 -m pip install --upgrade pre-commit
-          # TODO: ignore the first yapf failure until https://github.com/google/yapf/issues/1164 is fixed
-          python3 -m pre_commit run --all-files --verbose yapf &> /dev/null || true
-          # If first run of yapf worked and made changes reset the tree to the original state
-          git reset --hard
           python3 -m pre_commit run --all-files --verbose

       - name: Print diff of changes if pre-commit failed

.github/workflows/llvm-build.yml

Lines changed: 5 additions & 3 deletions
@@ -245,12 +245,14 @@ jobs:

           # Create temporary container to copy cache and installed artifacts.
           CONTAINER_ID=$(docker create llvm-build)
+
+          # We remove the existing directories, otherwise docker cp will
+          # create a subdirectory inside the existing directory.
+          rm -rf "${{ env.SCCACHE_DIR }}" "${{ env.llvm_install_dir }}"
+
           docker cp "${CONTAINER_ID}:/install" "${{ env.llvm_install_dir }}"
           tar czf "${{ env.llvm_install_dir }}.tar.gz" "${{ env.llvm_install_dir }}"

-          # We remove the existing directory, otherwise docker will
-          # create a subdirectory inside the existing directory.
-          rm -rf "${{ env.SCCACHE_DIR }}"
           docker cp "${CONTAINER_ID}:/sccache" "${{ env.SCCACHE_DIR }}"
           sudo chown -R "$(id -u -n):$(id -g -n)" "${{ env.SCCACHE_DIR }}"
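The nesting behavior this fix works around can be reproduced with plain `cp -R`, which follows the same rule as `docker cp` for directory targets: copying a directory onto an existing directory path nests the source one level deeper instead of replacing it. A minimal sketch (temporary paths are illustrative):

```shell
# Demonstrate the directory-nesting pitfall the workflow comment describes.
work=$(mktemp -d)
mkdir -p "$work/src"
touch "$work/src/file"

# Destination does not exist yet: "dest" becomes a copy of "src".
cp -R "$work/src" "$work/dest"
ls "$work/dest/file" >/dev/null

# Destination now exists: "src" is nested inside it instead of replacing it.
cp -R "$work/src" "$work/dest"
ls "$work/dest/src/file" >/dev/null
```

This is why the workflow deletes both `${{ env.SCCACHE_DIR }}` and `${{ env.llvm_install_dir }}` before either `docker cp`, so the copied directories land at the expected paths rather than one level down.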

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 2 deletions
@@ -100,11 +100,11 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
   }];

   let arguments = (
-    ins TT_FloatTensor:$src,
+    ins TT_FloatLike:$src,
     OptionalAttr<TT_RoundingModeAttr>:$rounding
   );

-  let results = (outs TT_FloatTensor:$result);
+  let results = (outs TT_FloatLike:$result);

   let assemblyFormat = "$src attr-dict (`,` `rounding` `=` $rounding^)? `:` type($src) `->` type($result)";
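Switching the operand and result from `TT_FloatTensor` to `TT_FloatLike` lets `tt.fp_to_fp` accept scalar floats in addition to float tensors. A hypothetical IR sketch, following the `assemblyFormat` above (the `rounding` value shown is illustrative; valid values are defined by `TT_RoundingModeAttr`):

```mlir
// Tensor form, as before (rounding required for a downcast).
%a = tt.fp_to_fp %t, rounding = rtne : tensor<16xf32> -> tensor<16xf16>
// Scalar form, newly accepted via TT_FloatLike.
%b = tt.fp_to_fp %s, rounding = rtne : f32 -> f16
```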

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 14 additions & 6 deletions
@@ -734,26 +734,34 @@ OpFoldResult FpToFpOp::fold(FoldAdaptor adaptor) {
   auto srcVal = getSrc();
   auto dstTy = getType();

-  const llvm::fltSemantics &semantic =
-      llvm::cast<FloatType>(dstTy.getElementType()).getFloatSemantics();
+  auto resElemType = cast<FloatType>(getElementTypeOrSelf(getType()));
+  const llvm::fltSemantics &semantic = resElemType.getFloatSemantics();

   if (matchPattern(srcVal, m_PosZeroFloat())) {
     llvm::APFloat posZero =
         llvm::APFloat::getZero(semantic, /*negative=*/false);
-    return DenseFPElementsAttr::get(dstTy, posZero);
+    if (auto tensorTy = dyn_cast<RankedTensorType>(dstTy))
+      return DenseElementsAttr::get(tensorTy, posZero);
+    return Builder(getContext()).getFloatAttr(resElemType, posZero);
   }

   if (matchPattern(srcVal, m_NegZeroFloat())) {
     llvm::APFloat negZero = llvm::APFloat::getZero(semantic, /*negative=*/true);
-    return DenseFPElementsAttr::get(dstTy, negZero);
+    if (auto tensorTy = dyn_cast<RankedTensorType>(dstTy))
+      return DenseElementsAttr::get(tensorTy, negZero);
+    return Builder(getContext()).getFloatAttr(resElemType, negZero);
   }

   return {};
 }

 LogicalResult FpToFpOp::verify() {
-  auto dstType = getType().getElementType();
-  auto srcType = getSrc().getType().getElementType();
+  auto dstType = getType();
+  auto srcType = getSrc().getType();
+  if (auto dstTensorType = dyn_cast<RankedTensorType>(dstType))
+    dstType = dstTensorType.getElementType();
+  if (auto srcTensorType = dyn_cast<RankedTensorType>(srcType))
+    srcType = srcTensorType.getElementType();
   if ((dstType.getIntOrFloatBitWidth() < srcType.getIntOrFloatBitWidth()) &&
       (!getRounding().has_value())) {
     return emitError("Rounding mode is required for FP downcast");
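The generalized verifier rule can be restated in a short Python sketch (all names here are illustrative, not triton APIs): take the element type "or self" when the value may be either a tensor or a scalar, then require an explicit rounding mode whenever the destination element width is smaller than the source's.

```python
# A sketch of the FpToFpOp::verify logic, assuming a toy type encoding:
# a bare int is a scalar float's bit width, and a (shape, elem_width)
# tuple stands in for a ranked tensor type.

def element_bitwidth(ty):
    """Element bit width of `ty` -- the 'element type or self' step."""
    if isinstance(ty, tuple):      # tensor: (shape, elem_width)
        return ty[1]
    return ty                      # scalar: elem_width itself

def verify_fp_to_fp(src_ty, dst_ty, rounding=None):
    """Raise if a downcast is attempted without a rounding mode."""
    if element_bitwidth(dst_ty) < element_bitwidth(src_ty) and rounding is None:
        raise ValueError("Rounding mode is required for FP downcast")
```

For example, `verify_fp_to_fp(32, 16)` raises, while `verify_fp_to_fp(((16,), 32), ((16,), 16), rounding="rtne")` passes: same rule, now applied uniformly to scalars and tensors.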

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 29 additions & 17 deletions
@@ -1,10 +1,7 @@
 #include "mlir/IR/BuiltinTypes.h"
-#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
-#include "triton/Dialect/TritonNvidiaGPU/IR/Types.h"
-#include "llvm/Support/raw_ostream.h"

 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
@@ -39,19 +36,6 @@ LogicalResult UpcastMXFPOp::verify() {
     return emitOpError("NYI: fpType must be E2M1, E4M3, or E5M2");
   }

-  // Change to support fp8 types
-  const auto elems_packed = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
-
-  if (xShape.back() != (32 / elems_packed) * scaleShape.back()) {
-    return emitOpError("last dimension of first operand must be 16 times "
-                       "larger than that of the second operand");
-  }
-
-  if (!std::equal(xShape.begin(), xShape.end() - 1, scaleShape.begin())) {
-    return emitOpError(
-        "all dimensions except the last must match between operands");
-  }
-
   auto layoutX = xTy.getEncoding();
   auto layoutScale = scaleTy.getEncoding();
   if (bool(layoutX) != bool(layoutScale)) {
@@ -82,6 +66,28 @@ LogicalResult UpcastMXFPOp::verify() {
     }
   }

+  // Change to support fp8 types
+  const auto elemsPacked = fpType == ScaleDotElemType::E2M1 ? 2 : 1;
+  // Figure out the K dimension for the input A/B. For A/B scale, the K
+  // dimension is always the last dimension.
+  const int opIdx = dotEncoding.getOpIdx();
+  const bool hasBatch = xShape.size() == 3;
+  const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+
+  if (xShape[kIdx] != (32 / elemsPacked) * scaleShape.back()) {
+    return emitOpError("K dimension of first operand must be 16 times "
+                       "larger than last/K dimension of the second operand");
+  }
+
+  // Check other dimensions match too. For input A/B, we need to figure out the
+  // index for the M/N dimension. For scale, it's always {(batch), M/N, K}.
+  const int mnIdx = (opIdx == 0 ? 0 : 1) + hasBatch;
+  if (hasBatch && xShape[0] != scaleShape[0])
+    return emitOpError("batch dimension must match between operands");
+  if (xShape[mnIdx] != scaleShape[hasBatch]) {
+    return emitOpError("M/N dimension must match between operands");
+  }
+
   return success();
 }

@@ -100,14 +106,20 @@ LogicalResult UpcastMXFPOp::inferReturnTypes(
   RankedTensorType retTy;

   auto newShape = SmallVector<int64_t>(xShape);
-  newShape.back() *= 2;
   if (!encoding) {
+    newShape.back() *= 2;
     retTy = RankedTensorType::get(xShape, FloatType::getBF16(ctx));
   } else {
     auto oldEncoding = cast<DotOperandEncodingAttr>(encoding);
     auto newVEncoding = DotOperandEncodingAttr::get(
         ctx, oldEncoding.getOpIdx(), oldEncoding.getParent(),
         oldEncoding.getKWidth() * 2);
+    // Figure out the K dimension for the input A/B, given that the return
+    // type is upcasted A/B type so we need to update the proper dim size.
+    const int opIdx = oldEncoding.getOpIdx();
+    const bool hasBatch = xShape.size() == 3;
+    const int kIdx = (opIdx == 0 ? 1 : 0) + hasBatch;
+    newShape[kIdx] *= 2;
     retTy = RankedTensorType::get(newShape, FloatType::getBF16(ctx),
                                   newVEncoding);
   }
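The shape rules these two hunks introduce can be summarized in a Python sketch (names are illustrative, not triton APIs): the A/B operand is laid out as {(batch), M/N, K} for opIdx 0 and {(batch), K, M/N} for opIdx 1, the scale is always {(batch), M/N, K-groups} with one scale per 32 logical values, and E2M1 packs two fp4 values per stored element. For clarity this sketch doubles the result's K extent only for E2M1, where unpacking doubles the logical element count.

```python
def upcast_mxfp_shapes(x_shape, scale_shape, op_idx, fp_type):
    """Check mxfp operand/scale shapes and return the upcast result shape."""
    elems_packed = 2 if fp_type == "E2M1" else 1    # E2M1: two fp4 per element
    has_batch = len(x_shape) == 3
    k_idx = (1 if op_idx == 0 else 0) + has_batch   # K position in A/B
    mn_idx = (0 if op_idx == 0 else 1) + has_batch  # M/N position in A/B

    # One scale element covers a group of 32 logical values along K.
    assert x_shape[k_idx] == (32 // elems_packed) * scale_shape[-1], "K mismatch"
    if has_batch:
        assert x_shape[0] == scale_shape[0], "batch mismatch"
    assert x_shape[mn_idx] == scale_shape[int(has_batch)], "M/N mismatch"

    # Unpacking E2M1 doubles the logical K extent of the result;
    # fp8 inputs keep their shape in this sketch.
    new_shape = list(x_shape)
    if fp_type == "E2M1":
        new_shape[k_idx] *= 2
    return new_shape
```

For instance, an opIdx-0 E2M1 operand of shape (128, 64) with scales of shape (128, 4) verifies (64 packed elements hold 128 fp4 values, i.e. 4 groups of 32) and upcasts to shape (128, 128).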

0 commit comments
