
Commit 68a3466

anmyachev authored and whitneywhtsang committed
Fix build and test failures from '8a5862d'
Signed-off-by: Anatoly Myachev <[email protected]>

fix build

Signed-off-by: Anatoly Myachev <[email protected]>

fix test_core.py after merge

Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 2582391 commit 68a3466

5 files changed: +23 -21 lines changed

python/test/unit/language/test_core.py

Lines changed: 3 additions & 2 deletions
@@ -6419,13 +6419,14 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device, t
             # expect compute scratch buffer to not error on xpu
             raise
         pytest.skip("Can't compute scratch buffer size")
-    lds_size = get_hip_lds_size()
+    lds_size = triton.runtime.driver.active.utils.get_device_properties(
+        triton.runtime.driver.active.get_current_device())["max_shared_mem"] if is_xpu() else get_hip_lds_size()
     # consider int32 dtype in scratch buffer size,
     # because it is the largest dtype used in convert_layout in this test
     int32_size = 4
     # skip even if scratch buffer equal to lds_size, because real scratch buffer is typically larger due to padding
     if scratch_shape[0] * scratch_shape[1] * int32_size >= lds_size:
-        pytest.skip("Scratch buffer is too large")
+        pytest.xfail("Scratch buffer is too large")
     if is_cuda() and isinstance(interm_layout, PaddedSharedLayout):
         pytest.skip("PaddedSharedLayout is not supported on CUDA")
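The new lds_size expression queries the active driver for the device's shared-local-memory size on XPU and keeps get_hip_lds_size() as the HIP fallback. Below is a minimal standalone sketch of that query, assuming a Triton build where triton.runtime.driver.active is the XPU driver; the shared_mem_budget helper and the example shape are illustrative, not part of this change.

    # Sketch only: mirrors the device-property lookup used in test_convert2d above.
    # Assumes triton.runtime.driver.active exposes get_current_device() and
    # utils.get_device_properties(), as shown in the diff.
    import triton

    def shared_mem_budget() -> int:
        # Shared local memory (SLM) size in bytes reported for the current device.
        device = triton.runtime.driver.active.get_current_device()
        props = triton.runtime.driver.active.utils.get_device_properties(device)
        return props["max_shared_mem"]

    # Example decision mirroring the test: a 128x128 scratch buffer of int32 elements.
    int32_size = 4
    scratch_bytes = 128 * 128 * int32_size
    if scratch_bytes >= shared_mem_budget():
        print("scratch buffer would not fit in SLM")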

test/Conversion/intel/shared_to_dot_layout_convert.mlir

Lines changed: 9 additions & 11 deletions
@@ -24,7 +24,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     // COM: Start of ttg.local_load. Load the value from SLM to register.
     // CHECK: %[[WORK_ITEM_ID_:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
     // CHECK: %[[WORK_ITEM_ID:.*]] = llvm.trunc %[[WORK_ITEM_ID_]] : i64 to i32
-    // CHECK-COUNT-128: %[[LD_RES:.*]] = llvm.load {{.*}} {alignment = 2 : i64} : !llvm.ptr<3> -> vector<1xf16>
+    // CHECK-COUNT-128: %[[LD_RES:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #dot_operand_a>
 
     %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #dpas>
@@ -62,7 +62,7 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
     // COM: Start of ttg.local_load. Load the value from SLM to register.
     // CHECK: %[[WORK_ITEM_ID_:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
     // CHECK: %[[WORK_ITEM_ID:.*]] = llvm.trunc %[[WORK_ITEM_ID_]] : i64 to i32
-    // CHECK-COUNT-128: %[[LD_RES:.*]] = llvm.load {{.*}} {alignment = 2 : i64} : !llvm.ptr<3> -> vector<1xf16>
+    // CHECK-COUNT-128: %[[LD_RES:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
     %AA_DOT = ttg.local_load %AA : !ttg.memdesc<128x64xf16, #shared, #ttg.shared_memory> -> tensor<128x64xf16, #dot_operand_a>
 
     %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #dpas>
@@ -87,23 +87,21 @@ module attributes {"ttg.num-warps" = 32 : i32, "ttg.threads-per-warp" = 16 : i32
   // CHECK-SAME: %[[PTR_1:.*]]: !llvm.ptr<1>)
   // CHECK-SAME: attributes {intel_reqd_sub_group_size = 16 : i32, {{.*}}} {
   tt.func @convert_dot(%B: tensor<64x256xf16, #blocked1>) {
-    // CHECK-DAG: %[[CST_128:.*]] = llvm.mlir.constant(128 : i32) : i32
-    // CHECK-DAG: %[[CST_256:.*]] = llvm.mlir.constant(256 : i32) : i32
-    // CHECK-DAG: %[[CST_32:.*]] = llvm.mlir.constant(32 : i32) : i32
-    // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
-    // CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
+    // CHECK-DAG: %[[CST_14:.*]] = llvm.mlir.constant(14 : i32) : i32
+    // CHECK-DAG: %[[CST_13:.*]] = llvm.mlir.constant(13 : i32) : i32
+    // CHECK-DAG: %[[CST_12:.*]] = llvm.mlir.constant(12 : i32) : i32
+    // CHECK-DAG: %[[CST_11:.*]] = llvm.mlir.constant(11 : i32) : i32
+    // CHECK-DAG: %[[CST_10:.*]] = llvm.mlir.constant(10 : i32) : i32
+    // CHECK-DAG: %[[CST_9:.*]] = llvm.mlir.constant(9 : i32) : i32
     // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
-    // CHECK-DAG: %[[CST_16:.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
-    // CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
     // CHECK-DAG: %[[CST_0:.*]] = llvm.mlir.constant(0 : i32) : i32
     %BB = ttg.local_alloc %B : (tensor<64x256xf16, #blocked1>) -> !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory>
 
     // CHECK: llvm.call spir_funccc @_Z7barrierj
     // COM: Start of ttg.local_load. Load the value from SLM to register.
     // CHECK: %[[WORK_ITEM_ID_:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[CST_0]])
     // CHECK: %[[WORK_ITEM_ID:.*]] = llvm.trunc %[[WORK_ITEM_ID_]] : i64 to i32
-    // CHECK-COUNT-128: %[[LD_RES:.*]] = llvm.load {{.*}} {alignment = 2 : i64} : !llvm.ptr<3> -> vector<1xf16>
+    // CHECK-COUNT-128: %[[LD_RES:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> vector<1xf16>
     %BB_DOT = ttg.local_load %BB : !ttg.memdesc<64x256xf16, #shared, #ttg.shared_memory> -> tensor<64x256xf16, #dot_operand_b>
     %cst0 = arith.constant dense<0.000000e+00> : tensor<128x256xf32, #dpas>
     %cst1 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #dot_operand_a>

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 8 additions & 5 deletions
@@ -875,9 +875,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
   // CHECK-LABEL: convert_layout_blocked_shared
   tt.func @convert_layout_blocked_shared(%arg0: tensor<128x32xf32, #blocked0>) {
     // CHECK: llvm.store
-    // CHECK-SAME: vector<8xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
     // CHECK: llvm.store
-    // CHECK-SAME: vector<8xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
     %0 = ttg.local_alloc %arg0 : (tensor<128x32xf32, #blocked0>) -> !ttg.memdesc<128x32xf32, #shared0, #smem>
     tt.return
   }
@@ -1432,6 +1432,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
   // CHECK-LABEL: test_base_index_cache
   tt.func @test_base_index_cache(%arg0: tensor<128x32xf32, #blocked0>) {
+    // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -1449,6 +1450,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
   // CHECK-LABEL: test_index_cache_different_block
   tt.func @test_index_cache_different_block(%arg0: tensor<128x32xf32, #blocked0>, %arg1: i1) {
+    // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: llvm.mlir.constant(0 : i32) : i32
     // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
@@ -1890,7 +1892,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @vectorize_shmem_load
   // CHECK: llvm.load
-  // CHECK-SAME: {alignment = 8 : i64} : !llvm.ptr<3> -> vector<8xi8>
+  // CHECK-SAME: !llvm.ptr<3> -> vector<8xi8>
   // CHECK-NOT: llvm.load
   tt.func public @vectorize_shmem_load(%shmem : !ttg.memdesc<16x16xi8, #shared, #smem>) {
     %0 = ttg.local_load %shmem : !ttg.memdesc<16x16xi8, #shared, #smem> -> tensor<16x16xi8, #blocked>
@@ -1906,8 +1908,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: @vectorize_shmem_store
   // CHECK: llvm.store
-  // CHECK-SAME: {alignment = 64 : i64} : vector<16xi32>, !llvm.ptr<3>
-  // CHECK-NOT: llvm.store
+  // CHECK-SAME: vector<4xi32>, !llvm.ptr<3>
+  // CHECK: llvm.store
+  // CHECK-SAME: vector<4xi32>, !llvm.ptr<3>
   tt.func public @vectorize_shmem_store(%block : tensor<64x64xi32, #blocked>) {
     %0 = ttg.local_alloc %block : (tensor<64x64xi32, #blocked>) -> !ttg.memdesc<64x64xi32, #shared, #smem>
     tt.return

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ void TargetInfo::storeMatrixShared(RewriterBase &rewriter, Location loc,
 
 Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
                               std::optional<Value> ctaId, Type elemTy,
-                              Value pred) const {
+                              Value pred, Operation *localLoadOp) const {
   assert(cast<mlir::LLVM::LLVMPointerType>(ptr.getType()).getAddressSpace() ==
              3 &&
          "Invalid addr space for loadShared");

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h

Lines changed: 2 additions & 2 deletions
@@ -29,8 +29,8 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                     std::optional<Value> ctaId, Value val,
                     Value pred) const override;
   Value loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
-                    std::optional<Value> ctaId, Type elemTy,
-                    Value pred) const override;
+                    std::optional<Value> ctaId, Type elemTy, Value pred,
+                    Operation *localLoadOp = nullptr) const override;
   bool canUseStMatrix(RankedTensorType tensorTy, ArrayRef<unsigned> repShape,
                       ArrayRef<unsigned> paddedRepShape,
                       ArrayRef<unsigned> order,
