Commit e244747

Merge commit '6fedb78ccef7135e78967254c442ed4637335b15'
2 parents: 94b3473 + 6fedb78

43 files changed: +917 -380 lines


.github/workflows/create_release.yml

Lines changed: 13 additions & 12 deletions
@@ -43,30 +43,31 @@ jobs:
           tag_or_branch="${tag_or_branch#refs/heads/}"
           # replace directory separators with _ in branch name
           tag_or_branch="${tag_or_branch//\//_}"
+          if [[ ${tag_or_branch} == v* ]]; then
+            # strip trailing v from tag name
+            tag_or_branch="${tag_or_branch#v}"
+            # important: version must be fixed in setup.py
+            sed -i -e "s:^TRITON_VERSION = .*:TRITON_VERSION = '${tag_or_branch}':" setup.py || exit 1
+          fi
           echo "RELEASE_NAME=triton-$tag_or_branch" >> "$GITHUB_ENV"
-          echo "RELEASE_FILE=triton-$tag_or_branch.tar.gz" >> "$GITHUB_ENV"
       - name: Create source distribution
         run: |
-          # Create new folder with specified name so extracting the archive yields that
-          rm -rf "/tmp/$RELEASE_NAME"
-          cp -r "$PWD" "/tmp/$RELEASE_NAME"
-          mv "/tmp/$RELEASE_NAME" .
-          # Cleanup
-          find "$RELEASE_NAME" -name '.git*' -exec rm -rv {} \; || true
-          # Create archive
-          tar -czf "$RELEASE_FILE" "$RELEASE_NAME"
-          echo "Created source archive $RELEASE_FILE with content: $(ls -a "$RELEASE_NAME")"
+          pip install build || exit 1
+          python -m build -s || exit 1
+          cd dist || exit 1
+          release_file=( *.tar.gz )
+          echo "RELEASE_FILE=${release_file}" >> "$GITHUB_ENV"
       - name: Upload source distribution for release
         if: ${{ github.event_name == 'release' }}
         uses: softprops/action-gh-release@v2
         with:
-          files: ${{env.RELEASE_FILE}}
+          files: dist/${{env.RELEASE_FILE}}
       - name: Upload source distribution to GHA artifacts for release tags
         if: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') && contains(github.ref, 'rc') }}
         uses: actions/[email protected]
         with:
           name: ${{ env.RELEASE_FILE }}
-          path: ${{ env.RELEASE_FILE }}
+          path: dist/${{ env.RELEASE_FILE }}
       - name: Set output
         id: release_name
         run: echo "name=release_name::${{ env.RELEASE_NAME }}.tar.gz" >> "${GITHUB_OUTPUT}"
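The rewritten "Create source distribution" step hands archive creation to the standard PEP 517 `build` front end instead of copying the checkout and tar-ing it by hand, then recovers the sdist filename with a glob rather than precomputing it. A minimal local sketch of the same flow (run from the repository root; the array glob assumes `dist/` holds exactly one sdist):

    # Sketch of the new sdist flow, mirroring the workflow step above.
    pip install build || exit 1   # PEP 517 build front end
    python -m build -s || exit 1  # writes triton-<version>.tar.gz into dist/
    cd dist || exit 1
    release_file=( *.tar.gz )     # bash array glob; unsubscripted expansion yields element 0
    echo "RELEASE_FILE=${release_file}"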

.github/workflows/wheels.yml

Lines changed: 6 additions & 1 deletion
@@ -79,7 +79,12 @@ jobs:
           export CIBW_BUILD="cp3{9,10,11,12,13,13t}-manylinux_${{ matrix.config.arch }}"
           export CIBW_SKIP="cp{35,36,37,38}-*"
           export CIBW_FREE_THREADED_SUPPORT=1
-          python3 -m cibuildwheel python --output-dir wheelhouse
+          python3 -m cibuildwheel . --output-dir wheelhouse
+
+      - uses: actions/upload-artifact@v4
+        with:
+          name: cibw-wheels-manylinux_2_28_${{ matrix.config.arch }}-wheels-upload
+          path: ./wheelhouse/*.whl
 
       - name: Install Azure CLI
         if: ${{ steps.check-version.outputs.new_commit == 'true' }}

MANIFEST.in

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+graft bin
+graft cmake
+graft docs
+graft include
+graft lib
+graft python/src
+graft python/test
+graft python/triton/backends/amd
+graft python/triton/backends/nvidia
+graft python/triton/tools/extra/cuda
+graft test
+graft third_party
+graft unittest
+include CMakeLists.txt
+include Makefile
+include python/build_helpers.py
+include python/requirements.txt
+include python/test-requirements.txt
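The new MANIFEST.in determines what `python -m build -s` packs into the sdist now that the package root has moved to the repository top level. A quick local check that the grafted trees actually land in the archive (a sketch; the archive name is illustrative):

    # Build an sdist and list a sample of its contents.
    python -m build -s
    tar -tzf dist/triton-*.tar.gz | grep -E '/(cmake|include|third_party)/' | head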

Makefile

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ dev-install-torch:
 
 .PHONY: dev-install-triton
 dev-install-triton:
-	$(PYTHON) -m pip install -e python --no-build-isolation -v
+	$(PYTHON) -m pip install -e . --no-build-isolation -v
 
 .PHONY: dev-install
 .NOPARALLEL: dev-install

README.md

Lines changed: 5 additions & 5 deletions
@@ -71,7 +71,7 @@ git clone https://github.com/triton-lang/triton.git
 cd triton
 
 pip install -r python/requirements.txt # build-time dependencies
-pip install -e python
+pip install -e .
 ```
 
 Or with a virtualenv:
@@ -84,7 +84,7 @@ python -m venv .venv --prompt triton
 source .venv/bin/activate
 
 pip install -r python/requirements.txt # build-time dependencies
-pip install -e python
+pip install -e .
 ```
 
 # Building with a custom LLVM
@@ -124,7 +124,7 @@ arbitrary LLVM version.
 $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
   LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
   LLVM_SYSPATH=$LLVM_BUILD_DIR \
-  pip install -e python
+  pip install -e .
 
 # Tips for building
 
@@ -139,7 +139,7 @@ arbitrary LLVM version.
   can be changed anytime.
 
 - If you're running out of memory when building Triton, specify the `MAX_JOBS`
-  environment variable (to the `pip install -e python` command) to limit the
+  environment variable (to the `pip install -e .` command) to limit the
   number of jobs.
 
 - Pass `--no-build-isolation` to `pip install` to make nop builds faster.
@@ -150,7 +150,7 @@ arbitrary LLVM version.
   (probably because, in our build, users don't invoke cmake directly, but
   instead use setup.py). Teach vscode how to compile Triton as follows.
 
-  - Do a local build. Run command `pip install -e python`
+  - Do a local build. Run command `pip install -e .`
   - Get the full path to the `compile_commands.json` file produced by the build:
     `find python/build -name 'compile_commands.json' | xargs readlink -f`.
     You might get a full path similar to `/Users/{username}/triton/python/build/cmake.macosx-11.1-arm64-cpython-3.12/compile_commands.json`

bench/bench/bench_mlp.py

Lines changed: 19 additions & 7 deletions
@@ -98,7 +98,7 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     # -- benchmark --
     fpath = Path(f"logs/{name}/{batch}-{dim1}-{dim2}-{n_expts_tot}-{n_expts_act}-{x_dtype}-{w_dtype}.hatchet")
     fpath.parent.mkdir(parents=True, exist_ok=True)
-    x_dtype = {"bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
+    x_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.float8_e4m3fn}[x_dtype]
     # special treatment of fp8_e4m3 on AMD CDNA3 because it uses fp8_e4m3fnuz
     if x_dtype == torch.float8_e4m3fn and get_cdna_version() == 3:
         x_dtype = torch.float8_e4m3fnuz
@@ -140,17 +140,29 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
         min_time = max(min_time_flops, min_time_bytes)
         util = min_time / tot_time
     else:
-        util = "N/A"
+        util = 0.0
     tflops = sum([tot_flops[w] for w in [8, 16]]) / tot_time * 1e-3
     tbps = tot_bytes / tot_time * 1e-3
+    print(f"Utilization: {util:.0%}; {tflops:>6.1f} TFLOPs, {tbps:.1f} TB/s")
 
     return util, tflops, tbps
 
 
 if __name__ == "__main__":
     has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
-    qxdtype = "fp8" if has_native_mx4 else "bf16"
-    print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
-    print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
-    print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4"))
-    print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=1, name="llama4"))
+    if SPECS is None:
+        print("Current GPU has no specs provided, utilization is N/A")
+    if has_native_mx4:
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "mx4", TP=1, EP=1, name="dense")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "mx4", TP=4, EP=1, name="llama4")
+    else:
+        # bf16/fp16 x fp8 is skipped because matmul_ogs requires x and w to have
+        # the same type when not doing an mxfp operation
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense")
+        bench_mlp(8192, 8192, 8192, 1, 1, "fp16", "mx4", TP=1, EP=1, name="dense")
+        bench_mlp(8192, 8192, 8192, 1, 1, "bf16", "mx4", TP=1, EP=1, name="dense")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4")
+        bench_mlp(2048, 5120, 8192, 128, 4, "bf16", "mx4", TP=4, EP=1, name="llama4")
+        bench_mlp(2048, 5120, 8192, 128, 4, "fp16", "mx4", TP=4, EP=1, name="llama4")
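Since `util` is now always a float (0.0 when `SPECS` is unavailable, instead of the string "N/A"), the `{util:.0%}` format in the new print statement cannot raise, and callers can consume the return value without type checks. A small usage sketch (the import path is an assumption):

    # Hypothetical caller; bench_mlp returns (util, tflops, tbps) as floats.
    from bench.bench_mlp import bench_mlp  # import path assumed

    util, tflops, tbps = bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8",
                                   TP=1, EP=1, name="dense")
    if util == 0.0:  # no hardware SPECS; utilization not meaningful
        print(f"raw numbers only: {tflops:.1f} TFLOPs, {tbps:.1f} TB/s")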

include/triton/Dialect/Triton/IR/OpInterfaces.h

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 #define TRITON_IR_OP_INTERFACES_H_
 
 #include "mlir/IR/OpDefinition.h"
+#include "triton/Dialect/Triton/IR/Types.h"
 
 namespace mlir {
 
include/triton/Dialect/Triton/IR/TritonAttrDefs.td

Lines changed: 15 additions & 0 deletions
@@ -67,6 +67,21 @@ def TT_AtomicRMWAttr : I32EnumAttr<
   let cppNamespace = "::mlir::triton";
 }
 
+def TT_DescriptorReduceKindAttr : I32EnumAttr<
+  "DescriptorReduceKind", "",
+  [
+    I32EnumAttrCase<"ADD", 1, "add">,
+    I32EnumAttrCase<"MIN", 2, "min">,
+    I32EnumAttrCase<"MAX", 3, "max">,
+    I32EnumAttrCase<"INC", 4, "inc">,
+    I32EnumAttrCase<"DEC", 5, "dec">,
+    I32EnumAttrCase<"AND", 6, "and">,
+    I32EnumAttrCase<"OR", 7, "or">,
+    I32EnumAttrCase<"XOR", 8, "xor">,
+  ]> {
+  let cppNamespace = "::mlir::triton";
+}
+
 def TT_MemSyncScopeAttr : I32EnumAttr<
   "MemSyncScope", "",
   [

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 26 additions & 0 deletions
@@ -75,5 +75,31 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
   let verify = [{ return ::mlir::triton::impl::verifyDotOpInterface($_op); }];
 }
 
+def TT_DescriptorOpInterface : OpInterface<"DescriptorOpInterface"> {
+  let description = [{
+    Common interface to get the descriptor argument from an operation on tensor descriptors.
+  }];
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Get the descriptor",
+      /*retType=*/"::mlir::TypedValue<mlir::triton::TensorDescType>",
+      /*methodName=*/"getDesc",
+      /*args=*/(ins)>,
+  ];
+}
+
+def TT_DescriptorStoreLikeOpInterface : OpInterface<"DescriptorStoreLikeOpInterface", [TT_DescriptorOpInterface]> {
+  let cppNamespace = "::mlir::triton";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Get Source tensor",
+      /*retType=*/"::mlir::TypedValue<mlir::RankedTensorType>",
+      /*methodName=*/"getSrc",
+      /*args=*/(ins)>,
+  ];
+}
+
 
 #endif // TRITON_OP_INTERFACES
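These interfaces give passes a uniform way to reach the descriptor (and, for store-like ops, the source tensor) without matching each concrete op. A sketch of how client code might use the generated C++ classes (namespaces follow the `cppNamespace` above, which is assumed for the base interface as well; surrounding pass boilerplate omitted):

    // Sketch: query descriptor ops generically via the new interfaces.
    #include "triton/Dialect/Triton/IR/OpInterfaces.h"

    void inspectDescriptorOp(mlir::Operation *op) {
      if (auto descOp = llvm::dyn_cast<mlir::triton::DescriptorOpInterface>(op)) {
        auto desc = descOp.getDesc();  // TypedValue<TensorDescType>
        (void)desc;
      }
      if (auto storeLike =
              llvm::dyn_cast<mlir::triton::DescriptorStoreLikeOpInterface>(op)) {
        auto src = storeLike.getSrc();  // TypedValue<RankedTensorType>
        (void)src;
      }
    }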

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 25 additions & 6 deletions
@@ -1019,7 +1019,7 @@ def TT_MakeTensorDescOp : TT_Op<"make_tensor_descriptor", [
   let assemblyFormat = "$base `,` `[` $shape `]` `,` `[` $strides `]` attr-dict `:` type($base) `,` type($result)";
 
   let builders = [
-    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape)>
+    OpBuilder<(ins "Value":$base, "ValueRange":$shape, "ValueRange":$strides, "ArrayRef<int32_t>":$blockShape, "bool":$isSignedInteger)>
   ];
 
   let extraClassDeclaration = [{
@@ -1259,7 +1259,7 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
 }
 
 
-def TT_DescriptorLoadOp : TT_Op<"descriptor_load"> {
+def TT_DescriptorLoadOp : TT_Op<"descriptor_load", [TT_DescriptorOpInterface]> {
   let summary = "Load from descriptor";
   let description = [{
     This operation will be lowered to Nvidia TMA load operation on targets supporting it.
@@ -1287,7 +1287,7 @@ def TT_DescriptorLoadOp : TT_Op<"descriptor_load"> {
   let hasVerifier = 1;
 }
 
-def TT_DescriptorStoreOp : TT_Op<"descriptor_store"> {
+def TT_DescriptorStoreOp : TT_Op<"descriptor_store", [TT_DescriptorStoreLikeOpInterface]> {
   let summary = "store value based on descriptor";
   let description = [{
     This operation will be lowered to Nvidia TMA store operation on targets supporting it.
@@ -1304,11 +1304,30 @@ def TT_DescriptorStoreOp : TT_Op<"descriptor_store"> {
     $desc `[` $indices `]` `,` $src
     attr-dict `:` qualified(type($desc)) `,` type($src)
   }];
-
   let hasVerifier = 1;
 }
 
-def TT_DescriptorGatherOp : TT_Op<"descriptor_gather"> {
+def TT_DescriptorReduceOp : TT_Op<"descriptor_reduce", [TT_DescriptorStoreLikeOpInterface]> {
+  let summary = "performs a reducing store operation based on a descriptor";
+  let description = [{
+    This operation will be lowered to Nvidia TMA store operation on targets supporting it.
+    `desc` is a tensor descriptor object.
+    The shape and types of `src` must match the descriptor, otherwise the result is undefined.
+  }];
+  let arguments = (ins
+    TT_DescriptorReduceKindAttr:$kind,
+    Arg<TT_TensorDescType, "", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$desc,
+    TT_Tensor:$src,
+    Variadic<I32>:$indices
+  );
+
+  let assemblyFormat = [{
+    $kind `,` $desc `[` $indices `]` `,` $src
+    attr-dict `:` qualified(type($desc)) `,` type($src)
+  }];
+}
+
+def TT_DescriptorGatherOp : TT_Op<"descriptor_gather", [TT_DescriptorOpInterface]> {
   let summary = "gather multiple rows from a descriptor into a single tensor";
   let description = [{
     The `tt.descriptor_gather` op will be lowered to NVIDIA TMA
@@ -1341,7 +1360,7 @@ def TT_DescriptorGatherOp : TT_Op<"descriptor_gather"> {
   }];
 }
 
-def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter"> {
+def TT_DescriptorScatterOp : TT_Op<"descriptor_scatter", [TT_DescriptorStoreLikeOpInterface]> {
   let summary = "scatter multiple rows to a descriptor from a single tensor";
   let description = [{
     The `tt.descriptor_scatter` op will be lowered to NVIDIA TMA