
Commit 45bf460

Merge commit '2b41842577ce7203f51d3e975c18983b5dafb5d2'
2 parents: 66391a3 + 2b41842

File tree: 13 files changed (+455, -94 lines)


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ using namespace mlir::triton;
 #define fadd(...) rewriter.create<LLVM::FAddOp>(loc, __VA_ARGS__)
 #define mul(...) rewriter.create<LLVM::MulOp>(loc, __VA_ARGS__)
 #define fmul(...) rewriter.create<LLVM::FMulOp>(loc, __VA_ARGS__)
+#define fma(...) rewriter.create<LLVM::FMAOp>(loc, __VA_ARGS__)
+#define neg(...) rewriter.create<LLVM::FNegOp>(loc, __VA_ARGS__)
 #define smax(...) rewriter.create<LLVM::SMaxOp>(loc, __VA_ARGS__)
 #define umax(...) rewriter.create<LLVM::UMaxOp>(loc, __VA_ARGS__)
 #define fmax(...) rewriter.create<LLVM::MaxNumOp>(loc, __VA_ARGS__)

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 77 additions & 75 deletions
@@ -1240,58 +1240,60 @@ def ReturnOp : TT_Op<"return", [Pure, HasParent<"FuncOp">, /*MemRefsNormalizable
 
 
 def TT_ExperimentalDescriptorLoadOp : TT_Op<"experimental_descriptor_load", [MemoryEffects<[MemRead<GlobalMemory>]>]> {
-  let summary = "Load from descriptor";
-  let description = [{
-    This operation will be lowered to Nvidia TMA load operation on targets supporting it.
-    `desc` is a tensor descriptor object.
-    The destination tensor type and shape must match the descriptor otherwise the result is undefined.
+  let summary = "Load from descriptor";
+  let description = [{
+    This operation will be lowered to Nvidia TMA load operation on targets supporting it.
+    `desc` is a tensor descriptor object.
+    The destination tensor type and shape must match the descriptor otherwise the result is undefined.
 
-    This is an escape hatch and is only there for testing/experimenting.
-    This op will be removed in the future.
-  }];
-  let arguments = (
-    ins
-    TT_TensorDescType:$desc,
-    Variadic<I32>:$indices,
-    DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
-    DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
-  );
+    This is an escape hatch and is only there for testing/experimenting.
+    This op will be removed in the future.
+  }];
+  let arguments = (ins
+    TT_TensorDescType:$desc,
+    Variadic<I32>:$indices,
+    DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
+    DefaultValuedAttr<TT_EvictionPolicyAttr, "::mlir::triton::EvictionPolicy::NORMAL">:$evict
+  );
 
-  let results = (outs TT_Tensor:$result);
+  let results = (outs TT_Tensor:$result);
 
-  let assemblyFormat = [{
-    $desc `[` $indices `]`
-    oilist(
-      `cacheModifier` `=` $cache |
-      `evictionPolicy` `=` $evict
-    )
-    attr-dict `:` qualified(type($desc)) `->` type($result)
-  }];
+  let assemblyFormat = [{
+    $desc `[` $indices `]`
+    oilist(
+      `cacheModifier` `=` $cache |
+      `evictionPolicy` `=` $evict
+    )
+    attr-dict `:` qualified(type($desc)) `->` type($result)
+  }];
+
+  let hasVerifier = 1;
 }
 
 def TT_ExperimentalDescriptorStoreOp : TT_Op<"experimental_descriptor_store", [
   MemoryEffects<[MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>,
 ]> {
-  let summary = "store value based on descriptor";
-  let description = [{
-    This operation will be lowered to Nvidia TMA store operation on targets supporting it.
-    `desc` is a tensor descriptor object.
-    The shape and types of `src` must match the descriptor otherwise the result is undefined.
+  let summary = "store value based on descriptor";
+  let description = [{
+    This operation will be lowered to Nvidia TMA store operation on targets supporting it.
+    `desc` is a tensor descriptor object.
+    The shape and types of `src` must match the descriptor otherwise the result is undefined.
 
-    This is an escape hatch and is only there for testing/experimenting.
-    This op will be removed in the future.
-  }];
-  let arguments = (
-    ins
-    TT_TensorDescType:$desc,
-    TT_Tensor:$src,
-    Variadic<I32>:$indices
-  );
+    This is an escape hatch and is only there for testing/experimenting.
+    This op will be removed in the future.
+  }];
+  let arguments = (ins
+    TT_TensorDescType:$desc,
+    TT_Tensor:$src,
+    Variadic<I32>:$indices
+  );
 
-  let assemblyFormat = [{
-    $desc `[` $indices `]` `,` $src
-    attr-dict `:` qualified(type($desc)) `,` type($src)
-  }];
+  let assemblyFormat = [{
+    $desc `[` $indices `]` `,` $src
+    attr-dict `:` qualified(type($desc)) `,` type($src)
+  }];
+
+  let hasVerifier = 1;
 }
 
 def TT_ExperimentalTensormapCreateOp: TT_Op<
@@ -1301,46 +1303,46 @@ def TT_ExperimentalTensormapCreateOp: TT_Op<
   AttrSizedOperandSegments,
 ]
 > {
-  let summary = "Create a new TMA descriptor on device";
-  let arguments = (
-    ins
-    TT_PtrType:$desc_ptr,
-    TT_PtrType:$global_address,
-    Variadic<I32>:$box_dim,
-    Variadic<I32>:$global_dim,
-    Variadic<I64>:$global_stride,
-    Variadic<I32>:$element_stride,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<12>]>:$elem_type,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
-    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
-  );
-  let extraClassDeclaration = [{
-    int32_t getRank() {
-      return getBoxDim().size();
-    }
-  }];
-  let assemblyFormat = [{
-    $desc_ptr `,` $global_address `,`
-    `[` $box_dim `]` `,`
-    `[` $global_dim `]` `,`
-    `[` $global_stride `]` `,`
-    `[` $element_stride `]`
-    attr-dict `:` functional-type(operands, results)
-  }];
+  let summary = "Create a new TMA descriptor on device";
+  let arguments = (
+    ins
+    TT_PtrType:$desc_ptr,
+    TT_PtrType:$global_address,
+    Variadic<I32>:$box_dim,
+    Variadic<I32>:$global_dim,
+    Variadic<I64>:$global_stride,
+    Variadic<I32>:$element_stride,
+    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<12>]>:$elem_type,
+    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<2>]>:$interleave_layout,
+    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<3>]>:$swizzle_mode,
+    ConfinedAttr<I32Attr, [IntNonNegative, IntMaxValue<1>]>:$fill_mode
+  );
+  let extraClassDeclaration = [{
+    int32_t getRank() {
+      return getBoxDim().size();
+    }
+  }];
+  let assemblyFormat = [{
+    $desc_ptr `,` $global_address `,`
+    `[` $box_dim `]` `,`
+    `[` $global_dim `]` `,`
+    `[` $global_stride `]` `,`
+    `[` $element_stride `]`
+    attr-dict `:` functional-type(operands, results)
+  }];
 
-  let hasVerifier = 1;
+  let hasVerifier = 1;
 }
 
 def TT_ExperimentalTensormapFenceproxyAcquireOp: TT_Op<
   "experimental_tensormap_fenceproxy_acquire",
   [MemoryEffects<[MemWrite<GlobalMemory>]>]
 > {
-  let summary = "Acquire fence on a tensormap object";
-  let arguments = (ins TT_PtrType:$desc_ptr);
-  let assemblyFormat = [{
-    $desc_ptr attr-dict `:` qualified(type($desc_ptr))
-  }];
+  let summary = "Acquire fence on a tensormap object";
+  let arguments = (ins TT_PtrType:$desc_ptr);
+  let assemblyFormat = [{
+    $desc_ptr attr-dict `:` qualified(type($desc_ptr))
+  }];
 }
 

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_LLVM_DEBUG_ONLY",
     "TRITON_ENABLE_ASAN",
+    "TRITON_OVERRIDE_NV_CAPABILITY",
     "USE_IR_LOC",
     "NVPTX_ENABLE_DUMP",
    "TRITON_INTEL_ADVANCED_PATH",

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 21 additions & 0 deletions
@@ -1117,6 +1117,27 @@ LogicalResult GatherOp::inferReturnTypes(
   return success();
 }
 
+// -- ExperimentalDesciptorLoadOp --
+static LogicalResult verifyDesciptorLoadStoreType(Operation *op,
+                                                  TensorDescType desc,
+                                                  RankedTensorType tensor) {
+  RankedTensorType block = desc.getBlockType();
+  if (block.getShape() == tensor.getShape() &&
+      block.getElementType() == tensor.getElementType())
+    return success();
+  return op->emitOpError("tensor desciptor block and tensor types must match");
+}
+
+LogicalResult ExperimentalDescriptorLoadOp::verify() {
+  return verifyDesciptorLoadStoreType(*this, getDesc().getType(), getType());
+}
+
+// -- ExperimentalDesciptorStoreOp --
+LogicalResult ExperimentalDescriptorStoreOp::verify() {
+  return verifyDesciptorLoadStoreType(*this, getDesc().getType(),
+                                      getSrc().getType());
+}
+
 // -- ExperimentalTensormapCreateOp --
 LogicalResult ExperimentalTensormapCreateOp::verify() {
   auto rank = getBoxDim().size();

python/test/unit/language/test_conversions.py

Lines changed: 2 additions & 2 deletions
@@ -303,8 +303,8 @@ def upcast_test(src_dtype, dst_dtype, exponent_bits, mantissa_bits, exponent_bia
 ])
 def test_typeconvert_upcast(src_dtype, dst_dtype, device):
 
-    # On HIP, fp8e4nv upcasting is only supported to bf16, and it's only supported on MI300.
-    if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or not is_hip_mi300()):
+    # On HIP, fp8e4nv upcasting is only supported to bf16 and fp16, and it's only supported on MI300.
+    if src_dtype == 'float8e4nv' and is_hip() and (dst_dtype != 'bfloat16' or dst_dtype != 'float16' or not is_hip_mi300()):
         pytest.skip(f"upcasting {src_dtype} to {dst_dtype} not supported in this architecture")
 
     if ((src_dtype == 'float8e4nv' and is_cuda() and torch.cuda.get_device_capability(0) < (8, 9))

python/test/unit/language/test_core.py

Lines changed: 31 additions & 0 deletions
@@ -6136,6 +6136,37 @@ def mul_add(data):
     assert found_fma == enable_fp_fusion
 
 
+# -----------------------
+# test override_nv_compute_capability
+# -----------------------
+
+
+@pytest.mark.parametrize("nv_compute_capability", [70, 80, 90])
+@pytest.mark.parametrize("env_var_override", [False, True])
+def test_override_nv_compute_capability(nv_compute_capability, env_var_override, device):
+    if not is_cuda():
+        pytest.xfail('test_override_nv_compute_capability only for CUDA')
+
+    @triton.jit
+    def simple(data, out):
+        in_ptrs = data + tl.arange(0, 128)
+        out_ptrs = out + tl.arange(0, 128)
+        tl.store(out_ptrs, tl.load(in_ptrs) * 1.5 + 1.0)
+
+    data = torch.randn((128, ), device=device, dtype=torch.float32)
+    out = torch.empty_like(data)
+
+    if env_var_override:
+        os.environ["TRITON_OVERRIDE_NV_CAPABILITY"] = str(nv_compute_capability)
+        h = simple[(1, )](data, out)
+        os.environ.pop("TRITON_OVERRIDE_NV_CAPABILITY")
+    else:
+        h = simple[(1, )](data, out, override_nv_compute_capability=nv_compute_capability)
+    torch.testing.assert_close(data * 1.5 + 1.0, out)
+    ttgir_cc = re.search(r'cuda:(\d+)', h.asm["ttgir"])
+    assert int(ttgir_cc.group(1)) == nv_compute_capability
+
+
 # -----------------------
 # test propagate_nan
 # -----------------------

python/triton/language/core.py

Lines changed: 2 additions & 2 deletions
@@ -1265,7 +1265,7 @@ def __str__(self) -> str:
         return f"tensor_descriptor<{self.type}>"
 
     @builtin
-    def load(self, offsets: List[tensor], _builder=None) -> tensor:
+    def load(self, offsets: List[constexpr | tensor], _builder=None) -> tensor:
         """Load a block from the descriptor starting at the given element offsets.
 
         Values outside of the tensor bounds will be filled with zeros.
@@ -1275,7 +1275,7 @@ def load(self, offsets: List[tensor], _builder=None) -> tensor:
         return semantic.descriptor_load(self, offsets, "", "", _builder)
 
     @builtin
-    def store(self, offsets: List[tensor], value: tensor, _builder=None) -> tensor:
+    def store(self, offsets: List[constexpr | tensor], value: tensor, _builder=None) -> tensor:
         """Store a block from the descriptor starting at the given element offsets.
 
         Values outside of the tensor bounds will be ignored.
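
Read together with the new ops above, the updated type hints say that descriptor offsets may now be constexpr values as well as tensors. A hypothetical kernel fragment for illustration; how `desc` is obtained is outside this diff, and the block shape is made up:

import triton
import triton.language as tl


@triton.jit
def scale_block(desc, BLOCK_M: tl.constexpr):
    # `desc` is assumed to be an experimental tensor descriptor object visible
    # inside the kernel; its construction is not part of this commit.
    tile = desc.load([0, 0])              # lowers to tt.experimental_descriptor_load
    desc.store([BLOCK_M, 0], tile * 1.5)  # lowers to tt.experimental_descriptor_store

With the verifiers added in lib/Dialect/Triton/IR/Ops.cpp, a mismatch between the descriptor's block shape or element type and the loaded or stored tensor is now reported as a verification error rather than left as undefined behavior.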

python/triton/language/semantic.py

Lines changed: 3 additions & 2 deletions
@@ -1147,7 +1147,7 @@ def reinterpret_tensor_descriptor(desc_ptr: tl.tensor, block_ty: tl.block_type,
     return tl._experimental_tensor_descriptor_base(handle, block_ty)
 
 
-def descriptor_load(desc: tl.tensor, offsets, cache_modifier: str, eviction_policy: str,
+def descriptor_load(desc: tl._experimental_tensor_desciptor_base, offsets, cache_modifier: str, eviction_policy: str,
                     builder: ir.builder) -> tl.tensor:
     assert isinstance(desc, tl._experimental_tensor_descriptor_base)
     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
@@ -1156,7 +1156,8 @@ def descriptor_load(desc: tl.tensor, offsets, cache_modifier: str, eviction_poli
     return tl.tensor(x, desc.type)
 
 
-def descriptor_store(desc: tl.tensor, value: tl.tensor, offsets, builder: ir.builder) -> tl.tensor:
+def descriptor_store(desc: tl._experimental_tensor_descriptor_base, value: tl.tensor, offsets,
+                     builder: ir.builder) -> tl.tensor:
     assert isinstance(desc, tl._experimental_tensor_descriptor_base)
     offsets = _convert_to_ir_values(builder, offsets, require_i64=False)
     return tl.tensor(builder.create_descriptor_store(desc.handle, value.handle, offsets), tl.void)

test/Conversion/amd/math-denorm-handling.mlir

Lines changed: 64 additions & 3 deletions
@@ -1,5 +1,5 @@
-// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_FTZ
-// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm | FileCheck %s --check-prefix=LLVM_NO_FTZ
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" | FileCheck %s --check-prefixes=COMMON,LLVM_FTZ
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" | FileCheck %s --check-prefixes=COMMON,LLVM_NO_FTZ
 
 
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
@@ -16,7 +16,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
 
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
-  tt.func public @test_exp2(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
+  tt.func public @test_exp(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
     // LLVM_FTZ: llvm.exp2.f32
     // LLVM_NO_FTZ: llvm.exp2.f32
     %0 = math.exp %arg0 : tensor<64xf32, #blocked>
@@ -35,3 +35,64 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.targ
     tt.return
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @test_sqrt_f32(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
+    // LLVM_FTZ-LABEL: test_sqrt_f32
+    // LLVM_FTZ-NOT: llvm.fcmp "ogt"
+    // LLVM_FTZ: llvm.amdgcn.sqrt.f32
+    // LLVM_FTZ-NOT: llvm.fmul
+    // LLVM_FTZ-NOT: llvm.select
+    //
+    // LLVM_NO_FTZ-LABEL: test_sqrt_f32
+    // LLVM_NO_FTZ: llvm.fcmp "ogt"
+    // LLVM_NO_FTZ: llvm.fmul
+    // LLVM_NO_FTZ-NEXT: llvm.select
+    // LLVM_NO_FTZ-NEXT: llvm.amdgcn.sqrt.f32
+    // LLVM_NO_FTZ: llvm.fmul
+    // LLVM_NO_FTZ-NEXT: llvm.select
+    %0 = math.sqrt %arg0 : tensor<64xf32, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @test_sqrt_rn_f32(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
+    // LLVM_FTZ-LABEL: test_sqrt_rn_f32
+    // LLVM_FTZ: llvm.amdgcn.rsq.f32
+    // LLVM_FTZ: llvm.fmul
+    // LLVM_FTZ: llvm.fmul
+    // LLVM_FTZ: llvm.fneg
+    // LLVM_FTZ: llvm.intr.fma
+    // LLVM_FTZ-NEXT: llvm.intr.fma
+    // LLVM_FTZ-NEXT: llvm.intr.fma
+    // LLVM_FTZ-NEXT: llvm.fneg
+    // LLVM_FTZ-NEXT: llvm.intr.fma
+    // LLVM_FTZ-NEXT: llvm.intr.fma
+    // LLVM_FTZ-NEXT: llvm.intr.is.fpclass
+    // LLVM_FTZ-NEXT: llvm.select
+    //
+    // LLVM_NO_FTZ-LABEL: test_sqrt_rn_f32
+    // LLVM_NO_FTZ: llvm.intr.sqrt
+    %0 = tt.precise_sqrt %arg0 : tensor<64xf32, #blocked>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @test_sqrt_rn_f64(%arg0: tensor<64xf64, #blocked>) attributes {noinline = false} {
+    // COMMON-LABEL: test_sqrt_rn_f64
+    // COMMON: llvm.intr.sqrt
+    %0 = tt.precise_sqrt %arg0 : tensor<64xf64, #blocked>
+    tt.return
+  }
+}
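
Background only, inferred from the LLVM_FTZ check lines rather than stated anywhere in this diff: the matched sequence (rsq, two fmul, fneg, a run of fma, is.fpclass, select) has the shape of the usual Newton-Raphson refinement of a hardware reciprocal-square-root estimate, roughly

$$y_0 \approx \frac{1}{\sqrt{x}}, \qquad g_0 = x\,y_0, \qquad h_0 = \tfrac{1}{2}\,y_0,$$
$$r = \mathrm{fma}(-h_0, g_0, \tfrac{1}{2}), \qquad g_1 = \mathrm{fma}(g_0, r, g_0), \qquad h_1 = \mathrm{fma}(h_0, r, h_0),$$
$$d = \mathrm{fma}(-g_1, g_1, x), \qquad \sqrt{x} \approx \mathrm{fma}(d, h_1, g_1),$$

with the trailing is.fpclass/select patching up zero and infinite inputs. The test only pins down the instruction mix, not the exact constants.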
