
Commit 24bd5a6

Merge commit '76ed94df1924b2262be9b37d778b6e0ccccb1180'
2 parents: 7f80413 + 76ed94d

File tree: 28 files changed, +103 -1077 lines

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 1 deletion
@@ -85,7 +85,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
   mlir::registerTritonAMDGPUReorderInstructions();
-  mlir::registerTritonAMDGPUStreamPipeline();
   mlir::registerTritonAMDGPUStreamPipelineV2();
   mlir::registerTritonAMDGPUCanonicalizePointers();

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 2 additions & 7 deletions
@@ -460,17 +460,12 @@ def TT_ReshapeOp : TT_Op<"reshape", [Pure,
     If efficient_layout is set, this is a hint that the destination layout should be kept for performance reason.
     The compiler is still free to change it for better performance.
   }];
-  let arguments = (ins TT_Tensor:$src, BoolAttr:$allow_reorder, OptionalAttr<UnitAttr>:$efficient_layout);
+  let arguments = (ins TT_Tensor:$src, UnitAttr:$allow_reorder, UnitAttr:$efficient_layout);
   let results = (outs TT_Tensor:$result);
-  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
+  let assemblyFormat = "$src (`allow_reorder` $allow_reorder^)? (`efficient_layout` $efficient_layout^)? attr-dict `:` type($src) `->` type($result)";
   let hasCanonicalizeMethod = 1;
   let hasFolder = 1;
   let hasVerifier = 1;
-  let builders = [
-      OpBuilder<(ins "Type":$type, "Value":$src, "bool":$allow_reorder),
-      [{
-      build($_builder, $_state, type, src, allow_reorder, /*efficient_layout=*/UnitAttr());
-      }]>];
 }

 def TT_BroadcastOp : TT_Op<"broadcast", [Pure,

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
@@ -678,7 +678,7 @@ LogicalResult canonicalizeViewOrBroadcast(OpType op,
 }

 LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
-  if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
+  if (!op.getAllowReorder() || op.getEfficientLayout())
     return failure();
   return canonicalizeViewOrBroadcast(op, rewriter);
 }

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 2 additions & 3 deletions
@@ -2764,7 +2764,7 @@ struct CanonicalizeConvertFromReshape
       return failure();
     if (isExpensiveView(convert.getSrc().getType(), op.getType()))
       return failure();
-    if (!op.getAllowReorder() || op.getEfficientLayout().has_value())
+    if (!op.getAllowReorder() || op.getEfficientLayout())
       return failure();

     rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
@@ -2885,8 +2885,7 @@ struct CanonicalizeConvertFromConvert

     // cvt(reshape) -> reshape
     if (auto reshape = dyn_cast<ReshapeOp>(arg)) {
-      if (!reshape.getAllowReorder() ||
-          reshape.getEfficientLayout().has_value() ||
+      if (!reshape.getAllowReorder() || reshape.getEfficientLayout() ||
           isExpensiveView(reshape.getSrc().getType(), op.getType()))
         return failure();

lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp

Lines changed: 2 additions & 2 deletions
@@ -314,8 +314,8 @@ class TritonGPUOptimizeThreadLocalityPass
     IRMapping mapping;
     for (auto operand : reduce.getOperands()) {
       auto viewOp = builder.create<triton::ReshapeOp>(
-          reduce.getLoc(), viewOpTensorType, operand, /*allowReorder=*/true);
-      viewOp.setEfficientLayout(true);
+          reduce.getLoc(), viewOpTensorType, operand,
+          /*allowReorder=*/true, /*efficientLayout=*/true);
       mapping.map(operand, viewOp);
     }

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 2 deletions
@@ -556,8 +556,7 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
     RankedTensorType newDstType =
        RankedTensorType::get(reshapeDstType.getShape(),
                              reshapeDstType.getElementType(), targetEncoding);
-    return reshape.getAllowReorder() &&
-           !reshape.getEfficientLayout().has_value() &&
+    return reshape.getAllowReorder() && !reshape.getEfficientLayout() &&
           !triton::gpu::isExpensiveView(reshape.getSrc().getType(),
                                         newDstType);
   }

python/setup.py

Lines changed: 22 additions & 20 deletions
@@ -284,7 +284,8 @@ def download_and_copy(name, src_path, dst_path, variable, version, url_func):
         arch = {"x86_64": "64", "arm64": "aarch64", "aarch64": "aarch64"}[platform.machine()]
     except KeyError:
         arch = platform.machine()
-    url = url_func(arch, version)
+    supported = {"Linux": "linux", "Darwin": "linux"}
+    url = url_func(supported[system], arch, version)
     tmp_path = os.path.join(triton_cache_path, "nvidia", name)  # path to cache the download
     dst_path = os.path.join(base_dir, os.pardir, "third_party", "nvidia", "backend", dst_path)  # final binary path
     platform_name = "sbsa-linux" if arch == "aarch64" else "x86_64-linux"
@@ -500,61 +501,62 @@ def get_platform_dependent_src_path(subdir):

 download_and_copy(
     name="ptxas", src_path="bin/ptxas", dst_path="bin/ptxas", variable="TRITON_PTXAS_PATH",
-    version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda arch, version:
+    version=NVIDIA_TOOLCHAIN_VERSION["ptxas"], url_func=lambda system, arch, version:
     ((lambda version_major, version_minor1, version_minor2:
-      f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/linux-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2"
+      f"https://anaconda.org/nvidia/cuda-nvcc-tools/{version}/download/{system}-{arch}/cuda-nvcc-tools-{version}-0.tar.bz2"
      if int(version_major) >= 12 and int(version_minor1) >= 5 else
-     f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
+     f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
     (*version.split('.'))))
 download_and_copy(
     name="cuobjdump",
     src_path="bin/cuobjdump",
     dst_path="bin/cuobjdump",
     variable="TRITON_CUOBJDUMP_PATH",
     version=NVIDIA_TOOLCHAIN_VERSION["cuobjdump"],
-    url_func=lambda arch, version:
-    f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
+    url_func=lambda system, arch, version:
+    f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2",
 )
 download_and_copy(
     name="nvdisasm",
     src_path="bin/nvdisasm",
     dst_path="bin/nvdisasm",
     variable="TRITON_NVDISASM_PATH",
     version=NVIDIA_TOOLCHAIN_VERSION["nvdisasm"],
-    url_func=lambda arch, version:
-    f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
+    url_func=lambda system, arch, version:
+    f"https://anaconda.org/nvidia/cuda-nvdisasm/{version}/download/{system}-{arch}/cuda-nvdisasm-{version}-0.tar.bz2",
 )
 download_and_copy(
     name="cudacrt", src_path=get_platform_dependent_src_path("include"), dst_path="include",
-    variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda arch, version:
+    variable="TRITON_CUDACRT_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudacrt"], url_func=lambda system, arch, version:
     ((lambda version_major, version_minor1, version_minor2:
-      f"https://anaconda.org/nvidia/cuda-crt-dev_linux-{arch}/{version}/download/noarch/cuda-crt-dev_linux-{arch}-{version}-0.tar.bz2"
+      f"https://anaconda.org/nvidia/cuda-crt-dev_{system}-{arch}/{version}/download/noarch/cuda-crt-dev_{system}-{arch}-{version}-0.tar.bz2"
      if int(version_major) >= 12 and int(version_minor1) >= 5 else
-     f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2")
+     f"https://anaconda.org/nvidia/cuda-nvcc/{version}/download/{system}-{arch}/cuda-nvcc-{version}-0.tar.bz2")
     (*version.split('.'))))
 download_and_copy(
     name="cudart", src_path=get_platform_dependent_src_path("include"), dst_path="include",
-    variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda arch, version:
+    variable="TRITON_CUDART_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cudart"], url_func=lambda system, arch, version:
     ((lambda version_major, version_minor1, version_minor2:
-      f"https://anaconda.org/nvidia/cuda-cudart-dev_linux-{arch}/{version}/download/noarch/cuda-cudart-dev_linux-{arch}-{version}-0.tar.bz2"
+      f"https://anaconda.org/nvidia/cuda-cudart-dev_{system}-{arch}/{version}/download/noarch/cuda-cudart-dev_{system}-{arch}-{version}-0.tar.bz2"
      if int(version_major) >= 12 and int(version_minor1) >= 5 else
-     f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/linux-{arch}/cuda-cudart-dev-{version}-0.tar.bz2"
+     f"https://anaconda.org/nvidia/cuda-cudart-dev/{version}/download/{system}-{arch}/cuda-cudart-dev-{version}-0.tar.bz2"
      )(*version.split('.'))))
 download_and_copy(
     name="cupti", src_path=get_platform_dependent_src_path("include"), dst_path="include",
-    variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version:
+    variable="TRITON_CUPTI_INCLUDE_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"],
+    url_func=lambda system, arch, version:
     ((lambda version_major, version_minor1, version_minor2:
-      f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
+      f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
      if int(version_major) >= 12 and int(version_minor1) >= 5 else
-     f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2")
+     f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2")
     (*version.split('.'))))
 download_and_copy(
     name="cupti", src_path=get_platform_dependent_src_path("lib"), dst_path="lib/cupti",
-    variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda arch, version:
+    variable="TRITON_CUPTI_LIB_PATH", version=NVIDIA_TOOLCHAIN_VERSION["cupti"], url_func=lambda system, arch, version:
     ((lambda version_major, version_minor1, version_minor2:
-      f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/linux-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
+      f"https://anaconda.org/nvidia/cuda-cupti-dev/{version}/download/{system}-{arch}/cuda-cupti-dev-{version}-0.tar.bz2"
      if int(version_major) >= 12 and int(version_minor1) >= 5 else
-     f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/linux-{arch}/cuda-cupti-{version}-0.tar.bz2")
+     f"https://anaconda.org/nvidia/cuda-cupti/{version}/download/{system}-{arch}/cuda-cupti-{version}-0.tar.bz2")
     (*version.split('.'))))

 backends = [*BackendInstaller.copy(["intel", "nvidia", "amd"]), *BackendInstaller.copy_externals()]
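
Note: the url_func lambdas above now take a system argument in addition to arch and version, and download_and_copy maps platform.system() through the supported dict before calling them. A minimal standalone sketch of how the new convention resolves to a URL; the system, arch, and version values below are illustrative assumptions, not values from this commit:

# Standalone sketch (not part of setup.py): the new three-argument url_func convention.
# In setup.py, system comes from platform.system(); "12.4.99" is just an example version.
supported = {"Linux": "linux", "Darwin": "linux"}  # both currently map to the linux packages
system, arch, version = "Linux", "64", "12.4.99"

url_func = lambda system, arch, version: (
    f"https://anaconda.org/nvidia/cuda-cuobjdump/{version}/download/"
    f"{system}-{arch}/cuda-cuobjdump-{version}-0.tar.bz2")

print(url_func(supported[system], arch, version))
# https://anaconda.org/nvidia/cuda-cuobjdump/12.4.99/download/linux-64/cuda-cuobjdump-12.4.99-0.tar.bz2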

python/test/unit/test_debug.py

Lines changed: 12 additions & 13 deletions
@@ -10,7 +10,7 @@
                          for env_var in [True, False]\
                          ])
 @pytest.mark.forked
-def test_device_assert(cond, opt_flag, env_var, device="cuda"):
+def test_device_assert(cond, opt_flag, env_var, device):
     os.environ['TRITON_DEBUG'] = str(int(env_var))
     torch.zeros([1], dtype=torch.int32, device=device)

@@ -21,11 +21,11 @@ def _kernel(COND: tl.constexpr):
     if not cond and (opt_flag or env_var):
         with pytest.raises(RuntimeError):
             _kernel[(1, )](cond, debug=opt_flag)
-            torch.cuda.synchronize()
+            getattr(torch, device).synchronize()
         return

     _kernel[(1, )](cond, debug=opt_flag)
-    torch.cuda.synchronize()
+    getattr(torch, device).synchronize()


 @pytest.mark.parametrize("cond", [False, True])
@@ -43,19 +43,18 @@ def _kernel(COND: tl.constexpr):
     _kernel[(1, )](cond)


-def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func):
-    device = "cuda"
+def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref_func, device):
     x = torch.tensor([x], dtype=getattr(torch, x_dtype), device=device)
     y = torch.tensor([y], dtype=getattr(torch, y_dtype), device=device)
     z = torch.empty_like(x)
     if should_overflow and debug:
         with pytest.raises(RuntimeError) as exc_info:
             tri_func[(1, )](x, y, z, debug=debug)
-            torch.cuda.synchronize()
+            getattr(torch, device).synchronize()
         assert "device-side assert" in str(exc_info.value)
     else:
         tri_func[(1, )](x, y, z, debug=debug)
-        torch.cuda.synchronize()
+        getattr(torch, device).synchronize()
         assert int(z) == int(ref_func(x, y))


@@ -74,13 +73,13 @@ def _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, tri_func, ref
     (2**15 - 1, 1, 'int16', 'int16', True, True),
 ])
 @pytest.mark.forked
-def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
+def test_sanitize_int_add_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):

     @triton.jit
     def _kernel_add(X, Y, Z):
         tl.store(Z, tl.load(X) + tl.load(Y))

-    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y)
+    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_add, lambda x, y: x + y, device)


 # mul overflow
@@ -95,13 +94,13 @@ def _kernel_add(X, Y, Z):
     (-2**30, 2, 'int32', 'int32', True, False),
 ])
 @pytest.mark.forked
-def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
+def test_sanitize_int_mul_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):

     @triton.jit
     def _kernel_mul(X, Y, Z):
         tl.store(Z, tl.load(X) * tl.load(Y))

-    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y)
+    _test_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, _kernel_mul, lambda x, y: x * y, device)


 # sub overflow
@@ -115,10 +114,10 @@ def _kernel_mul(X, Y, Z):
     (-2**31, -1, 'int32', 'int32', True, False),
 ])
 @pytest.mark.forked
-def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow):
+def test_sanitize_int_sub_overflow(x, y, x_dtype, y_dtype, debug, should_overflow, device):

     @triton.jit
     def _kernel_sub(X, Y, Z):
         tl.store(Z, tl.load(X) - tl.load(Y))

-    _test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y)
+    _test_overflow(x, y, x_dtype, y_dtype, should_overflow, debug, _kernel_sub, lambda x, y: x - y, device)
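
Note: these tests now receive device from a pytest fixture instead of hard-coding "cuda", and synchronize via getattr(torch, device).synchronize(). The fixture itself is not part of this diff; a minimal hypothetical sketch of what a conftest.py-style fixture could look like (the actual suite may instead derive the device from the active Triton backend):

# Hypothetical conftest.py sketch; option name and default are assumptions.
import pytest

def pytest_addoption(parser):
    # e.g. `pytest --device=xpu` to run the debug tests on a non-CUDA backend
    parser.addoption("--device", action="store", default="cuda")

@pytest.fixture
def device(request):
    return request.config.getoption("--device")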

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 1 addition & 1 deletion
@@ -506,7 +506,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK-NEXT: [[STRUCT2:%.*]] = llvm.insertvalue [[ARG0_1]], [[STRUCT1]][1]
   // CHECK-NEXT: [[T0:%.*]] = llvm.extractvalue [[STRUCT2]][0]
   // CHECK-NEXT: [[T1:%.*]] = llvm.extractvalue [[STRUCT2]][1]
-  %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
+  %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
   // CHECK: [[RES:%.*]] = llvm.mlir.undef : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
   // CHECK-NEXT: [[RES1:%.*]] = llvm.insertvalue [[T0]], [[RES]][0]
   // CHECK-NEXT: [[RES2:%.*]] = llvm.insertvalue [[T1]], [[RES1]][1]

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -357,7 +357,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
   // CHECK: llvm.mlir.undef
   // CHECK: %[[T0:.*]] = llvm.extractvalue
   // CHECK: %[[T1:.*]] = llvm.extractvalue
-  %0 = tt.reshape %arg {allow_reorder = true} : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
+  %0 = tt.reshape %arg allow_reorder : tensor<256xf32, #blocked0> -> tensor<256x1xf32,#blocked2>
   // CHECK: llvm.mlir.undef
   // CHECK: llvm.insertvalue %[[T0]]
   // CHECK: llvm.insertvalue %[[T1]]
