1 change: 1 addition & 0 deletions bin/RegisterTritonDialects.h
@@ -87,6 +87,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
mlir::registerTritonAMDGPUReorderInstructions();
mlir::registerTritonAMDGPUStreamPipelineV2();
mlir::registerTritonAMDGPUCanonicalizePointers();
mlir::registerTritonAMDGPUConvertToBufferOps();

// TODO: register Triton & TritonGPU passes
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
2 changes: 2 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -82,6 +82,8 @@ class TargetInfoBase {

virtual int getSharedAddressSpace() const = 0;

virtual bool supportVectorizedAtomics() const = 0;

virtual ~TargetInfoBase() {}
};
} // namespace mlir::triton
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
@@ -13,6 +13,7 @@ namespace mlir::triton {
inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
// clang-format off
"AMDGCN_ENABLE_DUMP",
"AMDGCN_USE_BUFFER_OPS",
"DISABLE_FAST_REDUCTION",
"DISABLE_LLVM_OPT",
"DISABLE_MMA_V3",
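A note on the new entry: CACHE_INVALIDATING_ENV_VARS is the set of variables folded into the compilation cache key, so toggling AMDGCN_USE_BUFFER_OPS causes a recompile instead of reusing a cached kernel. Below is a minimal, hypothetical sketch of how such a flag is typically read on the Python side; the use_buffer_ops helper is illustrative and not part of this change.

import os

def use_buffer_ops() -> bool:
    # Illustrative only: treat "1"/"true"/"on" as enabling the buffer-ops path;
    # anything else keeps the default global/flat memory ops.
    return os.environ.get("AMDGCN_USE_BUFFER_OPS", "0").lower() in ("1", "true", "on")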
80 changes: 44 additions & 36 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -109,27 +109,30 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
}

// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
static bool isSupportedDotOpLayout(Attribute layout) {
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
return mma.isAmpere() && dot.getKWidth() == 8;
}
if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
return true;
}
return false;
};

LogicalResult
matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemDescType srcTy = op.getSrc().getType();
RankedTensorType dstTy = op.getType();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
// FIXME [Dot LL]
// Do for all DotOperandEncodingAttr once we have LLs for all of them
auto isAmpereLargeKWidth = [](Attribute layout) {
if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
return mma.isAmpere() && dot.getKWidth() == 8;
}
}
return false;
};
if (isa<SharedEncodingAttr>(srcLayout) &&
(isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
dstLayout) ||
isAmpereLargeKWidth(dstLayout))) {
isSupportedDotOpLayout(dstLayout))) {
return lowerSharedToDistributed(op, adaptor, getTypeConverter(),
rewriter);
}
@@ -167,10 +170,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
auto srcTy = op.getSrc().getType();
auto dstTy = op.getResult().getType();
auto dstShape = dstTy.getShape();
assert(dstShape.size() <= 2 &&
"Unexpected rank of ConvertLayout(shared->blocked)");
auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding());
auto dstLayout = dstTy.getEncoding();
assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
"Unexpected rank of ConvertLayout(shared->distributed)");
auto inOrd = getOrder(srcSharedLayout);

auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
@@ -184,31 +187,36 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
// FIXME [Dot LL]
// Ampere case
// In this case, we need to pack the outputs into i32
if (isa<DotOperandEncodingAttr>(dstTy.getEncoding())) {
if (elemLlvmTy.isInteger(8)) {
auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
return or_(or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
or_(shl(zext(i32_ty, a3), i32_val(16)),
shl(zext(i32_ty, a4), i32_val(24))));
};
SmallVector<Value> outVals32(outVals.size() / 4);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
outVals[4 * i + 2], outVals[4 * i + 3]);
}
outVals = outVals32;
} else {
assert(elemLlvmTy.isBF16() && "Unexpected element type");
auto concat = [&](Value a, Value b) {
return or_(zext(i32_ty, bitcast(a, i16_ty)),
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
};
if (auto dotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding())) {
if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOp.getParent())) {
if (parent.isAmpere()) {
if (elemLlvmTy.isInteger(8)) {
auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
return or_(
or_(zext(i32_ty, a1), shl(zext(i32_ty, a2), i32_val(8))),
or_(shl(zext(i32_ty, a3), i32_val(16)),
shl(zext(i32_ty, a4), i32_val(24))));
};
SmallVector<Value> outVals32(outVals.size() / 4);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[4 * i], outVals[4 * i + 1],
outVals[4 * i + 2], outVals[4 * i + 3]);
}
outVals = outVals32;
} else {
assert(elemLlvmTy.isBF16() && "Unexpected element type");
auto concat = [&](Value a, Value b) {
return or_(zext(i32_ty, bitcast(a, i16_ty)),
shl(zext(i32_ty, bitcast(b, i16_ty)), i32_val(16)));
};

SmallVector<Value> outVals32(outVals.size() / 2);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
SmallVector<Value> outVals32(outVals.size() / 2);
for (int i = 0; i < outVals32.size(); ++i) {
outVals32[i] = concat(outVals[2 * i], outVals[2 * i + 1]);
}
outVals = outVals32;
}
}
outVals = outVals32;
}
}

3 changes: 2 additions & 1 deletion python/test/unit/language/test_core.py
@@ -4026,10 +4026,11 @@ def _kernel(dst, src, CACHE: tl.constexpr):
amdgcn = pgm.asm['amdgcn']
cg_cache_modifier_str = 'nt'
cv_cache_modifier_str = 'sc0 sc1'
buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line]
global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line]
if cache == '' or cache == '.ca':
assert cg_cache_modifier_str not in global_load_line[0]
assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0])
if cache == '.cg':
assert cg_cache_modifier_str in global_load_line[0]
if cache == '.cv':
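The updated assertion has to accept either buffer_load or global_load instructions, since the backend may emit either flavor depending on whether buffer ops are enabled. A small, hypothetical helper sketching the same idea (not part of this PR):

def first_load_line(amdgcn: str, mnemonics=("buffer_load", "global_load", "flat_load")):
    """Return the first assembly line that contains any of the given load mnemonics."""
    for line in amdgcn.splitlines():
        if any(m in line for m in mnemonics):
            return line
    raise AssertionError("no load instruction found in AMDGCN assembly")

With a helper like this, the cache-modifier checks would not need to branch on which load flavor the compiler produced.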
12 changes: 12 additions & 0 deletions test/Conversion/amd/builtin_func_to_llvm.mlir
@@ -0,0 +1,12 @@
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=True" --convert-builtin-func-to-llvm="ftz=True" | FileCheck %s --check-prefix=LLVM_FTZ
// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm="arch=gfx942 ftz=False" --convert-builtin-func-to-llvm="ftz=False" | FileCheck %s --check-prefix=LLVM_NO_FTZ

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "hip:gfx942", "triton_gpu.threads-per-warp" = 64 : i32} {
tt.func public @test_fast_expf(%arg0: tensor<64xf32, #blocked>) attributes {noinline = false} {
// LLVM_FTZ: llvm.amdgcn.exp2.f32
// LLVM_NO_FTZ: llvm.exp2.f32
%0 = tt.extern_elementwise %arg0 {libname = "libdevice", libpath = "", pure = true, symbol = "__triton_hip_fast_expf"} : (tensor<64xf32, #blocked>) -> tensor<64xf32, #blocked>
tt.return
}
}
15 changes: 8 additions & 7 deletions test/Conversion/amd/compute-base-ptr.mlir
@@ -1,18 +1,19 @@
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s
// RUN: triton-opt %s --split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 --mlir-print-debuginfo --mlir-pretty-debuginfo | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = false}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.shared = 544 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: @local_load_offset
tt.func @local_load_offset(%arg0: tensor<16x16xf16, #mma>) {
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked>
%1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
%0 = triton_gpu.convert_layout %arg0 {allocation.offset = 0 : i32} : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #blocked> loc(#loc1)
%1 = triton_gpu.local_alloc %0 {allocation.offset = 0 : i32} : (tensor<16x16xf16, #blocked>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> loc(#loc2)
// This catches the base ptr calculation in computeBasePtr and checks that the GEP has the correct element type.
// CHECK: llvm.sub
// CHECK-NEXT: llvm.getelementptr
// CHECK-SAME: (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
%2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>>
// CHECK: llvm.getelementptr {{.*}} (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16 local_load:3:0
%2 = triton_gpu.local_load %1 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 16}>> loc(#loc3)
tt.return
}
}
#loc1 = loc("conert_layout":1:0)
#loc2 = loc("local_alloc":2:0)
#loc3 = loc("local_load":3:0)
28 changes: 28 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -34,3 +34,31 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

// Smoke test to check that mfma 32 and dot operand layouts can work with small tensors, for example with shape 16x16
#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [32, 32], isTransposed = true}>
#dotop0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>
#dotop1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], hasLeadingOffset = false}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: small_mfma_tensor_conversions
tt.func public @small_mfma_tensor_conversions(%arg0: tensor<16x16xf16, #mfma>, %arg1: tensor<16x16x!tt.ptr<f32>, #mfma>) {
// CHECK-NOT: triton_gpu.convert_layout
%0 = triton_gpu.local_alloc %arg0 : (tensor<16x16xf16, #mfma>) -> !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory>
// CHECK-4: store {{.*}} vector<4xf16>
%1 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop0>
// CHECK-2: load {{.*}} vector<4xf16>
%2 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #dotop1>
// CHECK-8: load {{.*}} vector<1xf16>
%3 = triton_gpu.local_load %0 : !tt.memdesc<16x16xf16, #shared, #triton_gpu.shared_memory> -> tensor<16x16xf16, #mfma>
// CHECK-4: load {{.*}} vector<4xf16>
%4 = tt.fp_to_fp %3 : tensor<16x16xf16, #mfma> -> tensor<16x16xf32, #mfma>

%5 = tt.dot %1, %2, %4 : tensor<16x16xf16, #dotop0> * tensor<16x16xf16, #dotop1> -> tensor<16x16xf32, #mfma>
// Store result to prevent DCE from removing all conversion related code
%6 = triton_gpu.local_alloc %5 : (tensor<16x16xf32, #mfma>) -> !tt.memdesc<16x16xf32, #shared, #triton_gpu.shared_memory>
tt.return
}
}
2 changes: 1 addition & 1 deletion test/Conversion/tritongpu_to_llvm_hopper.mlir
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm=compute-capability=90 2>&1 | FileCheck %s
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=81' 2>&1 | FileCheck %s

#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
44 changes: 44 additions & 0 deletions test/Conversion/tritongpu_to_llvm_hopper_ptx80.mlir
@@ -0,0 +1,44 @@
// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-gpu-to-llvm='compute-capability=90 ptx-version=80' 2>&1 | FileCheck %s

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f32_nomask
// CHECK: atom.global.gpu.acq_rel.add.f32
// CHECK: atom.global.gpu.acq_rel.add.f32
// CHECK: atom.global.gpu.acq_rel.add.f32
// CHECK: atom.global.gpu.acq_rel.add.f32
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f32_withmask
// CHECK: atom.global.gpu.acq_rel.add.f32
// CHECK: atom.global.gpu.acq_rel.add.f32
// CHECK: atom.global.gpu.acq_rel.add.f32
// CHECK: atom.global.gpu.acq_rel.add.f32
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked>
tt.return
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} {
// CHECK-LABEL: atomic_add_f16_withmask
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
tt.return
}
}
40 changes: 40 additions & 0 deletions test/TritonGPU/amd/amd-canonicalize-pointers.mlir
@@ -89,6 +89,46 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
//
// This is the same as conversion3, but now the `arith.extsi` operations
// are gone and all the offsets are 32 bits.
//
// CHECK-LABEL: tt.func @conversion4
tt.func @conversion4(%arg0: !tt.ptr<f32>{tt.pointer_range = 32 : i32})-> tensor<1024xf32, #blocked>{
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>

//CHECK: %0 = tt.get_program_id x : i32
//CHECK: %[[pid:.*]] = arith.muli %0, {{.*}} : i32
//CHECK: %[[makerange:.*]] = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
//CHECK: %[[uniformOffset1:.*]] = arith.addi %[[pid]], {{.*}} : i32
//CHECK: %[[tensorOffset1:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked>
//CHECK: %[[uniformOffset0:.*]] = arith.addi %[[pid:.*]], %{{.*}} : i32
//CHECK: %[[tensorOffset3:.*]] = arith.addi %{{.*}}, %[[makerange]] : tensor<1024xi32, #blocked>
//CHECK: %[[zero:.*]] = tt.splat %{{.*}} : i32 -> tensor<1024xi32, #blocked>
//CHECK: %[[uniformPtr0:.*]] = tt.addptr %arg0, %[[uniformOffset0:.*]] : !tt.ptr<f32>, i32
//CHECK: %[[tensorOffset0:.*]]= arith.addi %[[tensorOffset3]], %[[zero]] : tensor<1024xi32, #blocked>
//CHECK: %[[uniformPtr1:.*]] = tt.addptr %[[uniformPtr0]], %[[uniformOffset1]] : !tt.ptr<f32>, i32
//CHECK: %[[tensorOffset2:.*]] = arith.addi %[[tensorOffset1]], %[[tensorOffset0]]: tensor<1024xi32, #blocked>
//CHECK: %[[scalarPtr:.*]] = tt.splat %[[uniformPtr1]] : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
//CHECK: %[[newPtr:.*]] = tt.addptr %[[scalarPtr]], %[[tensorOffset2]] : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
//CHECK: tt.load %[[newPtr]]
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
%7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
%8 = tt.load %7 : tensor<1024x!tt.ptr<f32>, #blocked>
tt.return %8 : tensor<1024xf32, #blocked>
}
}

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
// CHECK-LABEL: tt.func @forOp