Skip to content

Commit 262102c

Browse files
Merge commit '64fff0289fa9800e965ac661b38ac6a9dc0e6482'
2 parents f66eab1 + 64fff02 commit 262102c

File tree

22 files changed

+317
-394
lines changed

22 files changed

+317
-394
lines changed

include/triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -208,6 +208,27 @@ class ElementwiseOpConversionBase : public ConvertOpToLLVMPattern<SourceOp> {
208208
ModuleAxisInfoAnalysis &axisAnalysisPass;
209209
};
210210

211+
// Trivial case where we map elementwise to an existing LLVM operator
212+
template <typename SourceOp, typename DestOp>
213+
struct ElementwiseOpConversion
214+
: public ElementwiseOpConversionBase<
215+
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
216+
using Base =
217+
ElementwiseOpConversionBase<SourceOp,
218+
ElementwiseOpConversion<SourceOp, DestOp>>;
219+
using Base::Base;
220+
using OpAdaptor = typename Base::OpAdaptor;
221+
222+
// An interface to support variant DestOp builder.
223+
SmallVector<DestOp> createDestOps(SourceOp op, OpAdaptor adaptor,
224+
ConversionPatternRewriter &rewriter,
225+
Type elemTy, MultipleOperandsRange operands,
226+
Location loc) const {
227+
return {rewriter.create<DestOp>(loc, elemTy, operands[0],
228+
adaptor.getAttributes().getValue())};
229+
}
230+
};
231+
211232
} // namespace gpu
212233

213234
} // namespace mlir::triton

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 31 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,32 @@
2929
using namespace mlir;
3030
using namespace mlir::triton;
3131

32+
namespace mlir::LLVM {
33+
using namespace mlir::triton;
34+
35+
Value createConstantI1(Location loc, OpBuilder &rewriter, bool v);
36+
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);
37+
Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v);
38+
Value createConstantF16(Location loc, OpBuilder &rewriter, float v);
39+
Value createConstantBF16(Location loc, OpBuilder &rewriter, float v);
40+
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);
41+
Value createConstantF64(Location loc, OpBuilder &rewriter, double v);
42+
Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type);
43+
Value createIndexConstant(OpBuilder &builder, Location loc,
44+
const TypeConverter *converter, int64_t value);
45+
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
46+
int64_t value);
47+
48+
LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
49+
LLVMFuncOp funcOp, ValueRange args);
50+
LLVM::CallIntrinsicOp
51+
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
52+
TypeRange types, ValueRange args);
53+
} // namespace mlir::LLVM
54+
55+
// Is v an integer or floating-point scalar constant equal to 0?
56+
bool isConstantZero(Value v);
57+
3258
namespace mlir::triton {
3359

3460
// Returns CTA level thread idx
@@ -248,26 +274,15 @@ struct TritonLLVMOpBuilder {
248274
Value i1_val(int64_t val) { return int_val(1, val); }
249275
Value true_val() { return int_val(1, true); }
250276
Value false_val() { return int_val(1, false); }
251-
Value f16_val(float v) {
252-
auto type = type::f16Ty(builder->getContext());
253-
return builder->create<LLVM::ConstantOp>(loc, type,
254-
builder->getF16FloatAttr(v));
255-
}
256-
Value f32_val(float v) {
257-
auto type = type::f32Ty(builder->getContext());
258-
return builder->create<LLVM::ConstantOp>(loc, type,
259-
builder->getF32FloatAttr(v));
260-
}
261-
Value f64_val(double v) {
262-
auto type = type::f64Ty(builder->getContext());
263-
return builder->create<LLVM::ConstantOp>(loc, type,
264-
builder->getF64FloatAttr(v));
265-
}
277+
Value f16_val(float v) { return LLVM::createConstantF16(loc, *builder, v); }
278+
Value bf16_val(float v) { return LLVM::createConstantBF16(loc, *builder, v); }
279+
Value f32_val(float v) { return LLVM::createConstantF32(loc, *builder, v); }
280+
Value f64_val(double v) { return LLVM::createConstantF64(loc, *builder, v); }
266281
Value i8_val(int64_t val) { return int_val(8, val); }
267282
Value i16_val(int64_t val) { return int_val(16, val); }
268283
Value i32_val(int64_t val) { return int_val(32, val); }
269284
Value i64_val(int64_t val) { return int_val(64, val); }
270-
Value tid_val() { return getThreadId(*this->builder, loc); }
285+
Value tid_val() { return getThreadId(*builder, loc); }
271286

272287
Location loc;
273288
OpBuilder *builder;
@@ -375,24 +390,6 @@ LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
375390
namespace LLVM {
376391
using namespace mlir::triton;
377392

378-
Value createConstantI1(Location loc, OpBuilder &rewriter, bool v);
379-
Value createConstantI32(Location loc, OpBuilder &rewriter, int32_t v);
380-
Value createConstantI64(Location loc, OpBuilder &rewriter, int64_t v);
381-
Value createConstantF16(Location loc, OpBuilder &rewriter, float v);
382-
Value createConstantF32(Location loc, OpBuilder &rewriter, float v);
383-
Value createConstantF64(Location loc, OpBuilder &rewriter, double v);
384-
Value createNaNConstant(Location loc, OpBuilder &rewriter, Type type);
385-
Value createIndexConstant(OpBuilder &builder, Location loc,
386-
const TypeConverter *converter, int64_t value);
387-
Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
388-
int64_t value);
389-
390-
LLVM::CallOp createLLVMCallOp(OpBuilder &builder, Location loc,
391-
LLVMFuncOp funcOp, ValueRange args);
392-
LLVM::CallIntrinsicOp
393-
createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
394-
TypeRange types, ValueRange args);
395-
396393
// Is v an integer or floating-point scalar constant equal to 0?
397394
bool isConstantZero(Value v);
398395

lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 0 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -215,26 +215,6 @@ struct ExternElementwiseOpConversion
215215
}
216216
};
217217

218-
template <typename SourceOp, typename DestOp>
219-
struct ElementwiseOpConversion
220-
: public ElementwiseOpConversionBase<
221-
SourceOp, ElementwiseOpConversion<SourceOp, DestOp>> {
222-
using Base =
223-
ElementwiseOpConversionBase<SourceOp,
224-
ElementwiseOpConversion<SourceOp, DestOp>>;
225-
using Base::Base;
226-
using OpAdaptor = typename Base::OpAdaptor;
227-
228-
// An interface to support variant DestOp builder.
229-
SmallVector<DestOp> createDestOps(SourceOp op, OpAdaptor adaptor,
230-
ConversionPatternRewriter &rewriter,
231-
Type elemTy, MultipleOperandsRange operands,
232-
Location loc) const {
233-
return {rewriter.create<DestOp>(loc, elemTy, operands[0],
234-
adaptor.getAttributes().getValue())};
235-
}
236-
};
237-
238218
struct ElementwiseInlineAsmOpConversion
239219
: public ConvertOpToLLVMPattern<ElementwiseInlineAsmOp> {
240220
using Base = ConvertOpToLLVMPattern<ElementwiseInlineAsmOp>;

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -505,6 +505,15 @@ Value createConstantF16(Location loc, OpBuilder &rewriter, float v) {
505505
rewriter.getF16FloatAttr(v));
506506
}
507507

508+
Value createConstantBF16(Location loc, OpBuilder &rewriter, float v) {
509+
APFloat apf(v);
510+
bool ignored;
511+
apf.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &ignored);
512+
auto type = type::bf16Ty(rewriter.getContext());
513+
auto attr = FloatAttr::get(type, apf);
514+
return rewriter.create<LLVM::ConstantOp>(loc, type, attr);
515+
}
516+
508517
Value createConstantF32(Location loc, OpBuilder &rewriter, float v) {
509518
auto type = type::f32Ty(rewriter.getContext());
510519
return rewriter.create<LLVM::ConstantOp>(loc, type,

python/test/unit/language/test_core.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6466,7 +6466,7 @@ def kernel(x_ptr, min_ptr, max_ptr, out_ptr, ref_ptr, N, BLOCK_SIZE: tl.constexp
64666466
# Test for symmetric clamp(x, -limit, limit), as it may go through optimized
64676467
# codegen in the backends
64686468
@pytest.mark.interpreter
6469-
@pytest.mark.parametrize("dtype", ['float16', 'float32'])
6469+
@pytest.mark.parametrize("dtype", ['bfloat16', 'float16', 'float32'])
64706470
def test_clamp_symmetric(dtype, device):
64716471

64726472
@triton.jit
@@ -6545,7 +6545,7 @@ def test_tl_range(device):
65456545
if capability[0] >= 8:
65466546
ptx = pgm.asm['ptx']
65476547
# check that the loop got pipelined with the right number of stages.
6548-
assert 'cp.async.wait_group 0x6' in ptx
6548+
assert 'cp.async.wait_group 6' in ptx
65496549

65506550

65516551
@triton.jit(noinline=True)

test/Conversion/tma_to_llvm.mlir

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -22,8 +22,9 @@ tt.func @tma_gather_simple(%arg0: !tt.ptr<i8>, %arg1: !ttg.memdesc<1xi64, #share
2222
// CHECK: [[WIDX:%.*]] = lshr i32 [[TIDX]], 5
2323
// CHECK: [[WARP_ID:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[WIDX]],
2424

25-
// CHECK: [[ELECT:%.*]] = tail call i1 asm "elect.sync
26-
// CHECK: [[PRED:%.*]] = and i1 %5, [[ELECT]]
25+
// CHECK: [[ELECT:%.*]] = tail call { i32, i1 } @llvm.nvvm.elect.sync
26+
// CHECK: [[ELECT_PRED:%.*]] = extractvalue { i32, i1 } [[ELECT]], 1
27+
// CHECK: [[PRED:%.*]] = and i1 %5, [[ELECT_PRED]]
2728

2829
// CHECK: [[IDX0:%.*]] = extractvalue {{.*}} %2, 0
2930
// CHECK: [[IDX1:%.*]] = extractvalue {{.*}} %2, 1
@@ -142,8 +143,9 @@ tt.func @tma_gather_redundant_warps(%arg0: !tt.ptr<i8>, %arg1: !ttg.memdesc<1xi6
142143
// CHECK: [[WARP_SELECT:%.*]] = and i32 [[WARP_ID]], 2
143144
// CHECK: [[WARP_PRED:%.*]] = icmp eq i32 [[WARP_SELECT]], 0
144145
// CHECK: [[PRED_TMP:%.*]] = and i1 %5, [[WARP_PRED]]
145-
// CHECK: [[ELECT:%.*]] = tail call i1 asm "elect.sync
146-
// CHECK: [[PRED:%.*]] = and i1 [[ELECT]], [[PRED_TMP]]
146+
// CHECK: [[ELECT:%.*]] = tail call { i32, i1 } @llvm.nvvm.elect.sync
147+
// CHECK: [[ELECT_PRED:%.*]] = extractvalue { i32, i1 } [[ELECT]], 1
148+
// CHECK: [[PRED:%.*]] = and i1 [[ELECT_PRED]], [[PRED_TMP]]
147149

148150
// CHECK-COUNT-8: cp.async.bulk.tensor{{.*}}(i1 [[PRED]],
149151
ttng.async_tma_gather %arg0[%arg2, %arg3] %arg4, %arg1, %arg5 : !tt.ptr<i8>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked2}>>, i32, !ttg.memdesc<1xi64, #shared, #smem, mutable>, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>, i1
@@ -158,14 +160,15 @@ tt.func @tma_scatter(%arg0: !tt.ptr<i8>, %arg1: tensor<32xi32, #ttg.slice<{dim =
158160
// with `async_tma_gather`, so we don't need to re-test the indexing logic.
159161

160162
// CHECK: [[BASE_PTR:%.*]] = extractvalue {{.*}} %3, 0
161-
// CHECK: [[PRED:%.*]] = tail call i1 asm "elect.sync
163+
// CHECK: [[ELECT:%.*]] = tail call { i32, i1 } @llvm.nvvm.elect.sync
164+
// CHECK: [[PRED:%.*]] = extractvalue { i32, i1 } [[ELECT]], 1
162165

163166
// CHECK: [[PTR:%.*]] = getelementptr {{.*}} [[BASE_PTR]]
164167
// CHECK-NEXT: "@$0 cp.async.bulk.tensor.2d.tile::scatter4.global.shared::cta.bulk_group [$1, {$2, $3, $4, $5, $6}], [$7];"
165168
// CHECK-SAME: (i1 [[PRED]], ptr addrspace(1) %0, i32 %2, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, i32 {{%[0-9]+}}, ptr addrspace(3) [[PTR]])
166169
ttng.async_tma_scatter %arg0[%arg1, %arg2] %arg3 : !tt.ptr<i8>, tensor<32xi32, #ttg.slice<{dim = 0, parent = #blocked}>>, i32, !ttg.memdesc<32x128xbf16, #shared1, #smem, mutable>
167170

168-
// CHECK: asm sideeffect "cp.async.bulk.commit_group ;", ""()
171+
// CHECK: call void @llvm.nvvm.cp.async.commit.group()
169172

170173
// CHECK-NEXT: ret void
171174
tt.return

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -479,7 +479,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
479479
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
480480
// CHECK-LABEL: basic_program_id
481481
tt.func @basic_program_id() {
482-
// CHECK: llvm.inline_asm asm_dialect = att operand_attrs = [] "mov.u32 $0, %ctaid.x;", "=r" : () -> i32
482+
// CHECK: llvm.call_intrinsic "llvm.nvvm.read.ptx.sreg.ctaid.x"() : () -> i32
483483
%0 = tt.get_program_id x : i32
484484
tt.return
485485
}
@@ -553,7 +553,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
553553
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
554554
// CHECK-LABEL: basic_async_wait
555555
tt.func @basic_async_wait() {
556-
// CHECK: cp.async.wait_group 0x4
556+
// CHECK: nvvm.cp.async.wait.group 4
557557
ttg.async_wait {num = 4: i32}
558558
tt.return
559559
}
@@ -588,7 +588,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
588588
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
589589
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
590590
// CHECK: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x8, 0x8
591-
// CHECK: cp.async.commit_group
591+
// CHECK: nvvm.cp.async.commit.group
592592
%73 = ttg.async_copy_global_to_local %66, %subview : tensor<64x!tt.ptr<i64>, #slice1d0> -> !ttg.memdesc<64xi64, #shared1D, #smem, mutable>
593593
ttg.async_commit_group %73
594594
tt.return
@@ -628,8 +628,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
628628

629629
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
630630
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "cp.async.cg.shared.global [ ${{.*}} + 16 ], [ ${{.*}} + 0 ], 0x10, 0x10;"
631-
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att
632-
// CHECK-SAME: cp.async.commit_group
631+
// CHECK: nvvm.cp.async.commit.group
633632
%a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x64x!tt.ptr<f32>, #AL> -> !ttg.memdesc<16x64xf32, #A, #smem, mutable>
634633
ttg.async_commit_group
635634
tt.return
@@ -675,8 +674,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
675674
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
676675
// CHECK: llvm.inline_asm
677676
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
678-
// CHECK: llvm.inline_asm
679-
// CHECK-SAME: cp.async.commit_group
677+
// CHECK: nvvm.cp.async.commit.group
680678
%a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<16x32x!tt.ptr<f32>, #AL> -> !ttg.memdesc<16x32xf32, #A, #smem, mutable>
681679
ttg.async_commit_group
682680
tt.return
@@ -732,8 +730,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
732730
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
733731
// CHECK: llvm.inline_asm
734732
// CHECK-SAME: cp.async.ca.shared.global [ ${{.*}} + 0 ], [ ${{.*}} + 0 ], 0x4, 0x4
735-
// CHECK: llvm.inline_asm
736-
// CHECK-SAME: cp.async.commit_group
733+
// CHECK: nvvm.cp.async.commit.group
737734
%a = ttg.async_copy_global_to_local %a_ptr, %tensor : tensor<32x32x!tt.ptr<f32>, #AL> -> !ttg.memdesc<32x32xf32, #A, #smem, mutable>
738735
ttg.async_commit_group
739736
tt.return

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
1212
// CHECK: %[[P0:.+]] = llvm.icmp "eq" %[[WID]], %[[C0]] : i32
1313
// CHECK: %[[P1:.+]] = llvm.and %{{.*}}, %[[P0]] : i1
1414
// CHECK: llvm.cond_br %[[P1]]
15-
// CHECK: %[[E:.+]] = llvm.inline_asm asm_dialect = att operand_attrs = [] "elect.sync _|$0, 0xffffffff;", "=b" : () -> i1
15+
// CHECK: %[[E:.+]] = nvvm.elect.sync -> i1
1616
// CHECK-COUNT-8: @$5 tcgen05.mma.cta_group::1.kind::f16 [ $0 + 0 ], $1, $2, $3, $4;", "r,l,l,r,b,b" %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %[[E]]
1717
// CHECK: @$0 tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [$1];", "b,l" %[[E]]
1818
tt.func @tc_gen5_mma(%a: !ttg.memdesc<128x128xf16, #shared, #ttg.shared_memory>,

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -182,7 +182,7 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
182182
%cst = arith.constant dense<0.000000e+00> : tensor<1024xf32, #blocked>
183183
%neg_limit = arith.subf %cst, %limit : tensor<1024xf32, #blocked>
184184

185-
// CHECK: "min.xorsign.abs.f32 $0, $1, $2;", "=f,f,f"
185+
// CHECK-COUNT-8: nvvm.fmin.xorsign.abs.f
186186
%12 = tt.clampf %x, %neg_limit, %limit, propagateNan = none : tensor<1024xf32, #blocked>
187187
tt.return
188188
}
@@ -208,12 +208,12 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
208208
#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
209209
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
210210
// CHECK-LABEL: cvt_mma_to_dot_fp8
211-
// CHECK: prmt.b32
212-
// CHECK: prmt.b32
211+
// CHECK: nvvm.prmt
212+
// CHECK: nvvm.prmt
213213
// CHECK: nvvm.shfl.sync
214214
// CHECK: nvvm.shfl.sync
215-
// CHECK: prmt.b32
216-
// CHECK: prmt.b32
215+
// CHECK: nvvm.prmt
216+
// CHECK: nvvm.prmt
217217
tt.func @cvt_mma_to_dot_fp8(%a: tensor<128x64xf8E5M2, #mma>) {
218218
%opA = ttg.convert_layout %a : tensor<128x64xf8E5M2, #mma> -> tensor<128x64xf8E5M2, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
219219
tt.return

0 commit comments

Comments
 (0)