Skip to content

Commit c4201fa

Browse files
authored
Use 2D block loads for post-DPAS chained ops (#3000)
Use the 2D block load for tensor pointer loads where the layout is a DPAS layout but the result of the load is not directly used in the DPAS computation.
1 parent 6e466f4 commit c4201fa

File tree

2 files changed

+227
-23
lines changed

2 files changed

+227
-23
lines changed

test/TritonIntelGPU/blockptr_load.mlir

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,67 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
2828

2929
// -----
3030

31+
// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
32+
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
33+
// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
34+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
35+
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
36+
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
37+
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
38+
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
39+
tt.func public @matmul_no_scf_with_add_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
40+
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
41+
%c0_i32 = arith.constant 0 : i32
42+
%c1_i64 = arith.constant 1 : i64
43+
%ptrA = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot0>>
44+
%ptrB = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
45+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
46+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
47+
// CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
48+
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
49+
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
50+
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
51+
%ptrX = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #dpas>>
52+
// CHECK-COUNT-4: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
53+
%X = tt.load %ptrX {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x64xf32, #dpas>>
54+
// CHECK-COUNT-32: llvm.fadd {{.*}}, {{.*}}
55+
%0 = arith.addf %D, %X : tensor<64x64xf32, #dpas>
56+
tt.return
57+
}
58+
}
59+
60+
// -----
61+
62+
// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
63+
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
64+
// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
65+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
66+
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
67+
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
68+
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
69+
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
70+
tt.func public @matmul_no_scf_with_add_transpose_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
71+
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
72+
%c0_i32 = arith.constant 0 : i32
73+
%c1_i64 = arith.constant 1 : i64
74+
%ptrA = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot0>>
75+
%ptrB = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
76+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
77+
// CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
78+
// CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
79+
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
80+
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
81+
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
82+
%ptrX = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #dpas>>
83+
// CHECK-NOT: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
84+
%X = tt.load %ptrX {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<64x64xf32, #dpas>>
85+
%0 = arith.addf %D, %X : tensor<64x64xf32, #dpas>
86+
tt.return
87+
}
88+
}
89+
90+
// -----
91+
3192
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
3293
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
3394
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=1}>

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 166 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,9 @@ struct LoadOpConversion
499499
auto tensorType = cast<RankedTensorType>(resultType);
500500

501501
// Only lower loadOp with a DPAS layout or a dot-operand layout whose parent is DPAS.
502-
if (!hasDotDpasEncoding(tensorType))
502+
auto encoding = tensorType.getEncoding();
503+
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
504+
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
503505
return failure();
504506

505507
Attribute blockIOAttr =
@@ -514,20 +516,24 @@ struct LoadOpConversion
514516
"Only row_major or column_major is supported");
515517
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
516518

517-
DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
518-
auto dotOrder = dotLayout.getThreadOrder();
519-
size_t rank = dotOrder.size();
520-
const bool valueRowMajor =
521-
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
522-
assert((valueRowMajor ||
523-
(dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) &&
524-
"Only row_major or column_major is allowed");
525-
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
526-
527-
auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
519+
auto getOpIdx = [&]() -> DpasEncodingAttr::OpIdx {
520+
if (hasDpasLayout) {
521+
return DpasEncodingAttr::OpIdx::OperandC;
522+
} else {
523+
auto dotLayout = getDotEncoding(tensorType).value();
524+
return static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
525+
}
526+
};
527+
auto opIdx = getOpIdx();
528528

529-
auto opIdx = static_cast<DpasEncodingAttr::OpIdx>(dotLayout.getOpIdx());
530529
Type eltTy = tensorType.getElementType();
530+
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
531+
532+
auto dpasLayout = hasDpasLayout
533+
? cast<DpasEncodingAttr>(encoding)
534+
: cast<DpasEncodingAttr>(
535+
getDotEncoding(tensorType).value().getParent());
536+
531537
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
532538
unsigned numElems = getTotalElemsPerThread(resultType);
533539
SmallVector<int64_t> numReps =
@@ -543,6 +549,143 @@ struct LoadOpConversion
543549
SmallVector<Value> multiDimWarpId =
544550
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
545551

552+
if (hasDpasLayout) {
553+
// A block load with the DPAS layout but without the DotDpasLayout is
554+
// expected to follow the ordering of the DPAS output. For a 2D block
555+
// load, the rows are distributed across work items/SIMD lanes and the
556+
// column vectors are available for each work item to process. This layout
557+
// aligns to the DPAS layout as the DPAS operation output layout
558+
// distributes rows across work items.
559+
560+
size_t rank = dpasOrder.size();
561+
const bool valueRowMajor =
562+
(dpasOrder[rank - 2] == 1 && dpasOrder[rank - 1] == 0);
563+
assert((valueRowMajor ||
564+
(dpasOrder[rank - 2] == 0 && dpasOrder[rank - 1] == 1)) &&
565+
"Only row_major or column_major is allowed");
566+
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
567+
568+
if (isTransposeRequired) {
569+
// TODO: this would likely require a shuffle to match the expected
570+
// ordering coming out of the DPAS layout and requires more
571+
// investigation
572+
return failure();
573+
}
574+
575+
MLIRContext *ctx = rewriter.getContext();
576+
577+
Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
578+
579+
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
580+
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
581+
Type load2DGenXType =
582+
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
583+
elemsPerLane); // make it opaque type.
584+
585+
auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
586+
offsetBaseY] =
587+
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
588+
baseWidth = trunc(i32_ty, baseWidth);
589+
baseHeight = trunc(i32_ty, baseHeight);
590+
591+
auto pitch = trunc(i32_ty, rowStride);
592+
593+
SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
594+
unsigned outerDimWarpNum =
595+
std::min<unsigned>(warpsPerCTA[rank - 2],
596+
mlir::ceil<unsigned>(tensorShape[rank - 2],
597+
repClusterShape[rank - 2]));
598+
unsigned innerDimWarpNum =
599+
std::min<unsigned>(warpsPerCTA[rank - 1],
600+
mlir::ceil<unsigned>(tensorShape[rank - 1],
601+
repClusterShape[rank - 1]));
602+
Value outerDimWarpId =
603+
urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
604+
Value innerDimWarpId =
605+
urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
606+
int64_t numRepOuter = numReps[1];
607+
int64_t numRepInner = numReps[2];
608+
609+
std::array<unsigned, 2> replicaStride = {
610+
outerDimWarpNum * repClusterShape[rank - 2],
611+
innerDimWarpNum * repClusterShape[rank - 1]};
612+
std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
613+
repClusterShape[rank - 1]};
614+
615+
Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
616+
Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
617+
Value warpId0Offset = add(dimWarpId0, offsetBaseY);
618+
Value warpId1Offset = add(dimWarpId1, offsetBaseX);
619+
620+
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
621+
unsigned valOffset = 0;
622+
623+
SmallVector<Value> unpackedLoadedVals;
624+
625+
for (int m = 0; m < numRepOuter; ++m) {
626+
for (int n = 0; n < numRepInner; ++n) {
627+
for (int repM = 0; repM < repCluster[0]; ++repM) {
628+
629+
Value offsetY =
630+
add(warpId0Offset,
631+
i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
632+
for (int repN = 0; repN < repCluster[1]; ++repN) {
633+
Value offsetX =
634+
add(warpId1Offset,
635+
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));
636+
637+
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
638+
loc, load2DGenXType,
639+
/*ptr*/ base,
640+
/*base_width*/ mul(baseWidth, elemSizeInBytes),
641+
/*base_height*/ baseHeight,
642+
/*base_pitch*/ mul(pitch, elemSizeInBytes),
643+
/*x*/ trunc(i32_ty, offsetX),
644+
/*y*/ trunc(i32_ty, offsetY),
645+
/*elem_size_in_bits*/ elemSizeInBits,
646+
/*tile_width*/ elemsPerInstr[1],
647+
/*tile_height*/ elemsPerInstr[0],
648+
/*v_blocks*/ 1,
649+
/*transpose*/ false,
650+
/*vnni_transform*/ false);
651+
if (failed(load2dOp.verify())) {
652+
// Explicitly invoke verifier because `triton_gen` ops are
653+
// immediately lowered further to a builtin call.
654+
return failure();
655+
}
656+
657+
Value ret = bitcast(
658+
load2dOp, LLVM::getFixedVectorType(eltTy, elemsPerLane));
659+
660+
for (size_t i = 0; i < elemsPerLane; i++) {
661+
Value loaded = extract_element(eltTy, ret, i32_val(i));
662+
unpackedLoadedVals.push_back(loaded);
663+
}
664+
}
665+
}
666+
}
667+
}
668+
669+
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
670+
Type llvmResultStructTy = typeConverter->convertType(op.getType());
671+
Value resultStruct = packLLElements(
672+
loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
673+
rewriter.replaceOp(op, {resultStruct});
674+
675+
return success();
676+
}
677+
678+
DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
679+
auto dotOrder = dotLayout.getThreadOrder();
680+
681+
size_t rank = dotOrder.size();
682+
const bool valueRowMajor =
683+
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
684+
assert((valueRowMajor ||
685+
(dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) &&
686+
"Only row_major or column_major is allowed");
687+
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
688+
546689
bool isOperandA = (opIdx == DpasEncodingAttr::OpIdx::OperandA);
547690
SmallVector<unsigned> dpasInstShape = isOperandA
548691
? dpasLayout.getDPASInstShapeA()
@@ -573,11 +716,11 @@ struct LoadOpConversion
573716
// input operands to DPAS.
574717
// TODO: add support for int4 and int2.
575718
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
576-
unsigned elemBits = eltTy.getIntOrFloatBitWidth();
577-
if ((opsPerChannel == 4 && elemBits == 8) ||
578-
(opsPerChannel == 2 && elemBits == 16) ||
579-
(opsPerChannel == 1 && elemBits == 32)) {
580-
loadResultElemType = (isOperandA && elemBits != 32) ? i16_ty : i32_ty;
719+
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
720+
(opsPerChannel == 2 && elemSizeInBits == 16) ||
721+
(opsPerChannel == 1 && elemSizeInBits == 32)) {
722+
loadResultElemType =
723+
(isOperandA && elemSizeInBits != 32) ? i16_ty : i32_ty;
581724
packedElemsPerLanePerDPASInst =
582725
isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1)
583726
: elemsPerLanePerDPASInst / opsPerChannel;
@@ -651,7 +794,7 @@ struct LoadOpConversion
651794

652795
// PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
653796
// by enlarging the vBlocks.
654-
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8;
797+
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
655798
numOperandsPer2DloadN =
656799
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
657800
vBlocks = numOperandsPer2DloadN;
@@ -695,12 +838,12 @@ struct LoadOpConversion
695838
baseWidth = trunc(i32_ty, baseWidth);
696839
baseHeight = trunc(i32_ty, baseHeight);
697840

698-
unsigned originalElemBits = elemBits;
841+
const unsigned originalElemBits = elemSizeInBits;
699842
if (isTransposeRequired) {
700843
// Adjust the block IO parameters to comply with the HW's limitations on
701844
// transposing loads.
702845
tileWidth = tileWidth / (32 / originalElemBits);
703-
elemBits = 32;
846+
elemSizeInBits = 32;
704847
}
705848
Value elemSizeInBytes = i32_val(originalElemBits / 8);
706849

@@ -747,14 +890,14 @@ struct LoadOpConversion
747890
/*base_pitch*/ mul(pitch, elemSizeInBytes),
748891
/*x*/ trunc(i32_ty, offsetX),
749892
/*y*/ trunc(i32_ty, offsetY),
750-
/*elem_size_in_bits*/ elemBits,
893+
/*elem_size_in_bits*/ elemSizeInBits,
751894
/*tile_width*/ tileWidth,
752895
/*tile_height*/ tileHeight,
753896
/*v_blocks*/ vBlocks,
754897
/*transpose*/ isTransposeRequired,
755898
/*vnni_transform*/
756899
(usePackedType && !isOperandA && !isTransposeRequired &&
757-
eltTy.getIntOrFloatBitWidth() != 32));
900+
originalElemBits != 32));
758901
if (failed(load2dOp.verify())) {
759902
// Explicitly invoke verifier because `triton_gen` ops are
760903
// immediately lowered further to a builtin call.

0 commit comments

Comments
 (0)