diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
index 2b449e8301..98e68459fa 100644
--- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
+++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
@@ -88,7 +88,6 @@ class TargetInfoBase {
   virtual bool supportLdMatrix() const { return false; }
   virtual bool supportStMatrix() const { return false; }
   virtual bool isCuda() const { return false; }
-  virtual bool isXpu() const { return false; }

   // Annotate target specific information to local load operations during
   // lowering to LLVM. `llLoadOp` is the generated LLVM load op.
diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp
index 642a1918a8..77421e7ebd 100644
--- a/lib/Analysis/Allocation.cpp
+++ b/lib/Analysis/Allocation.cpp
@@ -208,10 +208,8 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
     auto dstTy = cvtLayout.getType();
     if (!cvtNeedsSharedMemory(srcTy, dstTy))
       return 0;
-    // Pesimistically take the max. We will revisit later
-    auto elems = std::max(getNumScratchElemsSwizzledCvt(srcTy, dstTy),
-                          getNumScratchElemsPaddedCvt(srcTy, dstTy));
-
+    // The generic pass uses swizzling
+    auto elems = getNumScratchElemsSwizzledCvt(srcTy, dstTy);
     return elems * getBitwidth(srcTy) / 8;
   }
   if (isa(op)) {
diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index bc64b1ccf4..c201541d3b 100644
--- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -63,7 +63,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     } else if (llvm::is_contained(dims, kWarp)) {
      // Case 2: Transfer between values in the same CTA, in which case we move
       // values through shared memory.
-      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      return success();
     } else if (llvm::is_contained(dims, kLane)) {
       // Case 3. Transfer between values in the same warp, in which case we try
       // to move values using warp shuffles, though if the pattern is
@@ -74,7 +75,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       // TODO: Since data is only transferred within a warp over shared memory,
       // we should use `bar.warp.sync` instead of `barrier`, which will improve
       // latency when warps issue barriers on different cycles.
-      return transferWithinBlock(op, srcLayout, dstLayout, adaptor, rewriter);
+      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
+      return success();
     } else if (llvm::is_contained(dims, kRegister)) {
       // Case 4. Transfer between values in the same thread, in which case we
       // simply reorder the elements of adaptor.getSrc().
@@ -110,27 +112,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
     return success();
   }

-  LogicalResult transferWithinBlock(ConvertLayoutOp op,
-                                    const LinearLayout &srcLayout,
-                                    const LinearLayout &dstLayout,
-                                    OpAdaptor adaptor,
-                                    ConversionPatternRewriter &rewriter) const {
-    assert(cvtNeedsSharedMemory(op.getSrc().getType(), op.getType()));
-
-    // Try to use swizzling to implement the conversion
-    // HACK Remove once XPU tests pass for the swizzling path
-    if (!targetInfo.isXpu()) {
-      transferWithinBlockSwizzling(op, adaptor.getSrc(), rewriter);
-      return success();
-    }
-
-    Value result = transferWithinBlockPadding(op, adaptor.getSrc(), targetInfo,
-                                              getTypeConverter(), rewriter);
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-
   SmallVector transferWithinBlockSwizzlingImpl(
       Location loc, ConversionPatternRewriter &rewriter,
       const LinearLayout &srcLayout, const LinearLayout &dstLayout,
diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir
index bfdd62e268..6747475247 100644
--- a/test/Analysis/test-allocation.mlir
+++ b/test/Analysis/test-allocation.mlir
@@ -201,22 +201,22 @@ tt.func @longlive(%A : !tt.ptr) {

 // This example triggers graph coloring with > 1 colors.
 // expected-remark @below {{multi_color}}
-// expected-remark @below {{size = 1504}}
+// expected-remark @below {{size = 1376}}
 tt.func @multi_color(%A : !tt.ptr) {
-  // expected-remark @below {{offset = 1152, size = 64}}
+  // expected-remark @below {{offset = 1024, size = 64}}
   %cst = ttg.local_alloc : () -> !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 1472, size = 32}}
+  // expected-remark @below {{offset = 1344, size = 32}}
   %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 1216, size = 128}}
+  // expected-remark @below {{offset = 1088, size = 128}}
   %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %1 = ttg.local_load %cst : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
   // expected-remark @below {{offset = 0, size = 128}}
   %cst_3 = ttg.local_alloc : () -> !ttg.memdesc<4x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %2 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %3 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   // expected-remark @below {{offset = 512, size = 256}}
   %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
@@ -226,7 +226,7 @@ tt.func @multi_color(%A : !tt.ptr) {
   %5 = ttg.local_load %cst_5 : !ttg.memdesc<4x8xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x8xf16, #AL>
   // expected-remark @below {{offset = 0, size = 512}}
   %cst_6 = ttg.local_alloc : () -> !ttg.memdesc<8x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 1344, size = 128}}
+  // expected-remark @below {{offset = 1216, size = 128}}
   %cst_7 = ttg.local_alloc : () -> !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %6 = ttg.local_load %cst_0 : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
   // expected-remark @below {{offset = 0, size = 512}}
@@ -237,7 +237,7 @@ tt.func @multi_color(%A : !tt.ptr) {
   %cst_10 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %7 = ttg.local_load %cst_1 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
   %8 = ttg.local_load %cst_4 : !ttg.memdesc<4x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x32xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %9 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %cst_11 = arith.constant dense<0.000000e+00> : tensor<4x4xf16, #AL>
   %10 = ttg.local_load %cst_7 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
@@ -248,16 +248,16 @@

 // This example triggers graph coloring with multiple rounds
 // expected-remark @below {{multi_color_multi_rounds}}
-// expected-remark @below {{size = 9504}}
+// expected-remark @below {{size = 9376}}
 tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) {
-  // expected-remark @below {{offset = 9472, size = 32}}
+  // expected-remark @below {{offset = 9344, size = 32}}
   %cst = ttg.local_alloc : () -> !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
-  // expected-remark @below {{offset = 9344, size = 128}}
+  // expected-remark @below {{offset = 9216, size = 128}}
   %cst_0 = ttg.local_alloc : () -> !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
   // expected-remark @below {{offset = 0, size = 8192}}
   %cst_1 = ttg.local_alloc : () -> !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
-  // expected-remark @below {{scratch offset = 8192, size = 1152}}
+  // expected-remark @below {{scratch offset = 8192, size = 1024}}
   %0 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %1 = ttg.local_load %cst : !ttg.memdesc<4x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<4x4xf16, #AL>
   // expected-remark @below {{offset = 8704, size = 128}}
@@ -267,7 +267,7 @@ tt.func @multi_color_multi_rounds(%arg0: !tt.ptr) {
   %cst_4 = ttg.local_alloc : () -> !ttg.memdesc<1x16x16xf16, #A_SHARED, #ttg.shared_memory, mutable>
   %3 = ttg.local_load %cst_0 : !ttg.memdesc<16x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<16x4xf16, #AL>
   %4 = ttg.local_load %cst_1 : !ttg.memdesc<1024x4xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<1024x4xf16, #AL>
-  // expected-remark @below {{scratch offset = 0, size = 1152}}
+  // expected-remark @below {{scratch offset = 0, size = 1024}}
   %5 = ttg.convert_layout %cst_2 : tensor<16x32xf16, #AL> -> tensor<16x32xf16, #BL>
   %6 = ttg.local_load %cst_3 : !ttg.memdesc<2x32xf16, #A_SHARED, #ttg.shared_memory, mutable> -> tensor<2x32xf16, #AL>
   tt.return
diff --git a/test/Conversion/intel/dpas_to_block_layout_convert.mlir b/test/Conversion/intel/dpas_to_block_layout_convert.mlir
index 4138695a91..bc1381495c 100644
--- a/test/Conversion/intel/dpas_to_block_layout_convert.mlir
+++ b/test/Conversion/intel/dpas_to_block_layout_convert.mlir
@@ -11,13 +11,14 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
     %cst = arith.constant dense<0.000000e+00> : tensor<128x256xf16, #mma>

     // CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
-    // CHECK-DAG: %[[CST_16384:.*]] = llvm.mlir.constant(16384 : i32) : i32
     // CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
+    // CHECK-DAG: %[[CST_387:.*]] = llvm.mlir.constant(387 : i32) : i32
     // CHECK-DAG: %[[CST_384:.*]] = llvm.mlir.constant(384 : i32) : i32
-    // CHECK-DAG: %[[CST_112:.*]] = llvm.mlir.constant(112 : i32) : i32
+    // CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
+    // CHECK-DAG: %[[CST_48:.*]] = llvm.mlir.constant(48 : i32) : i32
     // CHECK-DAG: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32
-    // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
-    // CHECK-DAG: %[[CST_6:.*]] = llvm.mlir.constant(6 : i32) : i32
+    // CHECK-DAG: %[[CST_14:.*]] = llvm.mlir.constant(14 : i32) : i32
+    // CHECK-DAG: %[[CST_12:.*]] = llvm.mlir.constant(12 : i32) : i32
     // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
     // CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -35,36 +36,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
     // CHECK: %[[VAL_26:.*]] = llvm.or %[[CST_0]], %[[VAL_25]] : i32
     // CHECK: %[[VAL_27:.*]] = llvm.shl %[[warpId]], %[[CST_4]] : i32
     // CHECK: %[[VAL_28:.*]] = llvm.or %[[VAL_26]], %[[VAL_27]] : i32
-    // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_384]] : i32
-    // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_6]] : i32
+    // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_3]] : i32
+    // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_14]] : i32
     // CHECK: %[[VAL_31:.*]] = llvm.xor %[[CST_0]], %[[VAL_30]] : i32
-    // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_112]] : i32
-    // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_1]] : i32
+    // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_387]] : i32
+    // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_4]] : i32
     // CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_31]], %[[VAL_33]] : i32
-    // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_15]] : i32
-    // CHECK: %[[VAL_36:.*]] = llvm.lshr %[[VAL_35]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_48]] : i32
+    // CHECK: %[[VAL_36:.*]] = llvm.shl %[[VAL_35]], %[[CST_1]] : i32
     // CHECK: %[[VAL_37:.*]] = llvm.xor %[[VAL_34]], %[[VAL_36]] : i32
-    // CHECK: %[[VAL_38:.*]] = llvm.xor %[[CST_0]], %[[VAL_37]] : i32
-    // CHECK: %[[VAL_39:.*]] = llvm.and %[[VAL_28]], %[[CST_511]] : i32
-    // CHECK: %[[VAL_40:.*]] = llvm.shl %[[VAL_39]], %[[CST_3]] : i32
-    // CHECK: %[[VAL_41:.*]] = llvm.xor %[[CST_0]], %[[VAL_40]] : i32
-    // CHECK: %[[VAL_42:.*]] = llvm.xor %[[CST_0]], %[[VAL_41]] : i32
-    // CHECK: %[[VAL_43:.*]] = llvm.xor %[[VAL_38]], %[[CST_0]] : i32
-    // CHECK: %[[VAL_44:.*]] = llvm.lshr %[[VAL_43]], %[[CST_8]] : i32
-    // CHECK: %[[VAL_45:.*]] = llvm.shl %[[VAL_44]], %[[CST_3]] : i32
-    // CHECK: %[[offset:.*]] = llvm.add %[[VAL_45]], %[[VAL_43]] : i32
-    // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
-    // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<1xf16>
-
-    // COM: Because the values per thread of DPAS layout is not contiguous. The values are stored in the SLM in a non-vectorized way.
-    // COM: Total 64 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16) = 64
-    // CHECK: llvm.store %[[VAL_66]], %[[VAL_65]] : vector<1xf16>, !llvm.ptr<3>
-    // CHECK-COUNT-63: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
+    // CHECK: %[[VAL_38:.*]] = llvm.and %[[VAL_28]], %[[CST_12]] : i32
+    // CHECK: %[[VAL_39:.*]] = llvm.lshr %[[VAL_38]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_40:.*]] = llvm.xor %[[VAL_37]], %[[VAL_39]] : i32
+    // CHECK: %[[VAL_41:.*]] = llvm.and %[[VAL_28]], %[[CST_64]] : i32
+    // CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8192]] : i1, i32
+    // CHECK: %[[VAL_44:.*]] = llvm.xor %[[VAL_40]], %[[VAL_43]] : i32
+    // CHECK: %[[VAL_45:.*]] = llvm.xor %[[CST_0]], %[[VAL_44]] : i32
+    // CHECK: %[[VAL_46:.*]] = llvm.mul %[[CST_0]], %[[CST_2]] : i32
+    // CHECK: %[[VAL_47:.*]] = llvm.xor %[[VAL_45]], %[[VAL_46]] : i32
+    // CHECK: %[[VAL_48:.*]] = llvm.xor %[[VAL_47]], %[[CST_0]] : i32
+    // CHECK: %[[offset:.*]] = llvm.add %[[VAL_48]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<2xf16>
+    // CHECK: %[[VAL_67:.*]] = llvm.insertelement {{.*}}, %[[VAL_66]]{{\[}}%[[CST_1]] : i32] : vector<2xf16>
+
+    // COM: Because the values per thread of the DPAS layout are contiguous, the values are stored in the SLM in a vectorized way.
+    // COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16*2) = 32
+    // CHECK: llvm.store %[[VAL_67]], %[[VAL_65]] : vector<2xf16>, !llvm.ptr<3>
+    // CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<2xf16>, !llvm.ptr<3>
     // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()

     // COM: Because the values per thread of blocked layout is contiguous. The values are loaded from the SLM in a vectorized way.
     // COM: Total 8 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*8) = 8
-    // CHECK-COUNT-8: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+    // CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16>

     %93 = ttg.convert_layout %cst {allocation.offset = 0 : i32} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
     %80 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked>
@@ -90,11 +95,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
     // CHECK-DAG: %[[CST_3:.*]] = llvm.mlir.constant(3 : i32) : i32
     // CHECK-DAG: %[[CST_8192:.*]] = llvm.mlir.constant(8192 : i32) : i32
     // CHECK-DAG: %[[CST_4096:.*]] = llvm.mlir.constant(4096 : i32) : i32
+    // CHECK-DAG: %[[CST_387:.*]] = llvm.mlir.constant(387 : i32) : i32
     // CHECK-DAG: %[[CST_384:.*]] = llvm.mlir.constant(384 : i32) : i32
-    // CHECK-DAG: %[[CST_112:.*]] = llvm.mlir.constant(112 : i32) : i32
+    // CHECK-DAG: %[[CST_64:.*]] = llvm.mlir.constant(64 : i32) : i32
+    // CHECK-DAG: %[[CST_48:.*]] = llvm.mlir.constant(48 : i32) : i32
     // CHECK-DAG: %[[CST_15:.*]] = llvm.mlir.constant(15 : i32) : i32
-    // CHECK-DAG: %[[CST_8:.*]] = llvm.mlir.constant(8 : i32) : i32
-    // CHECK-DAG: %[[CST_5:.*]] = llvm.mlir.constant(5 : i32) : i32
+    // CHECK-DAG: %[[CST_14:.*]] = llvm.mlir.constant(14 : i32) : i32
+    // CHECK-DAG: %[[CST_12:.*]] = llvm.mlir.constant(12 : i32) : i32
     // CHECK-DAG: %[[CST_4:.*]] = llvm.mlir.constant(4 : i32) : i32
     // CHECK-DAG: %[[CST_2:.*]] = llvm.mlir.constant(2 : i32) : i32
     // CHECK-DAG: %[[CST_1:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -113,43 +120,40 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32, ttg.sha
     // CHECK: %[[VAL_26:.*]] = llvm.or %[[CST_0]], %[[VAL_25]] : i32
     // CHECK: %[[VAL_27:.*]] = llvm.shl %[[warpId]], %[[CST_4]] : i32
     // CHECK: %[[VAL_28:.*]] = llvm.or %[[VAL_26]], %[[VAL_27]] : i32
-    // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_384]] : i32
-    // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_5]] : i32
+    // CHECK: %[[VAL_29:.*]] = llvm.and %[[VAL_28]], %[[CST_3]] : i32
+    // CHECK: %[[VAL_30:.*]] = llvm.shl %[[VAL_29]], %[[CST_14]] : i32
     // CHECK: %[[VAL_31:.*]] = llvm.xor %[[CST_0]], %[[VAL_30]] : i32
-    // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_112]] : i32
-    // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_1]] : i32
+    // CHECK: %[[VAL_32:.*]] = llvm.and %[[VAL_28]], %[[CST_387]] : i32
+    // CHECK: %[[VAL_33:.*]] = llvm.shl %[[VAL_32]], %[[CST_4]] : i32
     // CHECK: %[[VAL_34:.*]] = llvm.xor %[[VAL_31]], %[[VAL_33]] : i32
-    // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_15]] : i32
-    // CHECK: %[[VAL_36:.*]] = llvm.lshr %[[VAL_35]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_35:.*]] = llvm.and %[[VAL_28]], %[[CST_48]] : i32
+    // CHECK: %[[VAL_36:.*]] = llvm.shl %[[VAL_35]], %[[CST_1]] : i32
     // CHECK: %[[VAL_37:.*]] = llvm.xor %[[VAL_34]], %[[VAL_36]] : i32
-    // CHECK: %[[VAL_38:.*]] = llvm.xor %[[CST_0]], %[[VAL_37]] : i32
-    // CHECK: %[[VAL_39:.*]] = llvm.and %[[VAL_28]], %[[CST_511]] : i32
-    // CHECK: %[[VAL_40:.*]] = llvm.shl %[[VAL_39]], %[[CST_3]] : i32
-    // CHECK: %[[VAL_41:.*]] = llvm.xor %[[CST_0]], %[[VAL_40]] : i32
-    // CHECK: %[[VAL_42:.*]] = llvm.xor %[[CST_0]], %[[VAL_41]] : i32
-    // CHECK: %[[VAL_43:.*]] = llvm.xor %[[VAL_38]], %[[CST_0]] : i32
-    // CHECK: %[[VAL_44:.*]] = llvm.lshr %[[VAL_43]], %[[CST_8]] : i32
-    // CHECK: %[[VAL_45:.*]] = llvm.shl %[[VAL_44]], %[[CST_3]] : i32
-    // CHECK: %[[offset:.*]] = llvm.add %[[VAL_45]], %[[VAL_43]] : i32
-    // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, f16
-    // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<1xf16>
-
-    // COM: Because the values per thread of DPAS layout is not contiguous. The values are stored in the SLM in a non-vectorized way.
-    // COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 64*256/(4*8*16) = 32
-    // CHECK: llvm.store %[[VAL_66]], %[[VAL_65]] : vector<1xf16>, !llvm.ptr<3>
-    // CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
+    // CHECK: %[[VAL_38:.*]] = llvm.and %[[VAL_28]], %[[CST_12]] : i32
+    // CHECK: %[[VAL_39:.*]] = llvm.lshr %[[VAL_38]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_40:.*]] = llvm.xor %[[VAL_37]], %[[VAL_39]] : i32
+    // CHECK: %[[VAL_41:.*]] = llvm.and %[[VAL_28]], %[[CST_64]] : i32
+    // CHECK: %[[VAL_42:.*]] = llvm.icmp "eq" %[[VAL_41]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_43:.*]] = llvm.select %[[VAL_42]], %[[CST_0]], %[[CST_8192]] : i1, i32
+    // CHECK: %[[VAL_44:.*]] = llvm.xor %[[VAL_40]], %[[VAL_43]] : i32
+    // CHECK: %[[VAL_45:.*]] = llvm.xor %[[CST_0]], %[[VAL_44]] : i32
+    // CHECK: %[[VAL_46:.*]] = llvm.mul %[[CST_0]], %[[CST_2]] : i32
+    // CHECK: %[[VAL_47:.*]] = llvm.xor %[[VAL_45]], %[[VAL_46]] : i32
+    // CHECK: %[[VAL_48:.*]] = llvm.xor %[[VAL_47]], %[[CST_0]] : i32
+    // CHECK: %[[offset:.*]] = llvm.add %[[VAL_48]], %[[CST_0]] : i32
+    // CHECK: %[[VAL_65:.*]] = llvm.getelementptr inbounds %[[SMEM]]{{\[}}%[[offset]]] : (!llvm.ptr<3>, i32) -> !llvm.ptr<3>, i8
+    // CHECK: %[[VAL_66:.*]] = llvm.insertelement {{.*}}, {{.*}}{{\[}}%[[CST_0]] : i32] : vector<2xf16>
+    // CHECK: %[[VAL_67:.*]] = llvm.insertelement {{.*}}, %[[VAL_66]]{{\[}}%[[CST_1]] : i32] : vector<2xf16>
+
+    // COM: Because the values per thread of the DPAS layout are contiguous, the values are stored in the SLM in a vectorized way.
+    // COM: Total 32 stores are generated to save the tensor of the DPAS layout to the SLM. 128*256/(4*8*16*2) = 32
+    // CHECK: llvm.store %[[VAL_67]], %[[VAL_65]] : vector<2xf16>, !llvm.ptr<3>
+    // CHECK-COUNT-31: llvm.store {{.*}}, {{.*}} : vector<2xf16>, !llvm.ptr<3>
     // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()

     // COM: Because the values per thread of blocked layout is contiguous. The values are loaded from the SLM in a vectorized way.
-    // COM: Total 4 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*8) = 8
-    // CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
-
-    // COM: The 2nd round of exchanging values.
-    // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
-    // CHECK-COUNT-32: llvm.store {{.*}}, {{.*}} : vector<1xf16>, !llvm.ptr<3>
-    // CHECK: llvm.call spir_funccc @_Z7barrierj(%[[CST_1]]) {convergent, no_unwind, will_return} : (i32) -> ()
-    // CHECK-COUNT-4: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
-
+    // COM: Total 16 loads are generated to load the tensor of the blocked layout from the SLM. 128*256/(16*2*16*4) = 16
+    // CHECK-COUNT-16: {{.*}} = llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf16>

     %93 = ttg.convert_layout %cst {allocation.offset = 0 : i32} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked>
     %80 = tt.splat %arg0 : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked>
     %83 = tt.broadcast %80 : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x256x!tt.ptr, #blocked>
diff --git a/test/Conversion/intel/tritongpu_to_gen.mlir b/test/Conversion/intel/tritongpu_to_gen.mlir
index 2350e6eee6..5341e879d2 100644
--- a/test/Conversion/intel/tritongpu_to_gen.mlir
+++ b/test/Conversion/intel/tritongpu_to_gen.mlir
@@ -816,12 +816,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
   // CHECK-LABEL: convert_layout_dpas_block
   tt.func @convert_layout_dpas_blocked(%arg0: tensor<32x16xf32, #dpas>) {
     // CHECK: llvm.store
-    // CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<2xf32>, !llvm.ptr<3>
     // CHECK: llvm.store
-    // CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<2xf32>, !llvm.ptr<3>
     // CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
     // CHECK: llvm.load
-    // CHECK-SAME: !llvm.ptr<3> -> vector<4xf32>
+    // CHECK-SAME: !llvm.ptr<3> -> vector<2xf32>
     %0 = ttg.convert_layout %arg0 : tensor<32x16xf32, #dpas> -> tensor<32x16xf32, #blocked0>
     tt.return
   }
@@ -836,13 +836,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
   // CHECK-LABEL: convert_layout_dpas_block
   tt.func @convert_layout_dpas_blocked(%arg0: tensor<32x64xf32, #dpas>) {
     // CHECK: llvm.store
-    // CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
     // CHECK: llvm.store
-    // CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
     // CHECK: llvm.store
-    // CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
     // CHECK: llvm.store
-    // CHECK-SAME: vector<1xf32>, !llvm.ptr<3>
+    // CHECK-SAME: vector<4xf32>, !llvm.ptr<3>
     // CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
     // CHECK: llvm.load
     // CHECK-SAME: !llvm.ptr<3> -> vector<4xf32>
@@ -858,9 +858,9 @@
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
   // CHECK-LABEL: convert_layout_dpas_transpose
   tt.func @convert_layout_dpas_transpose(%arg0: tensor<128x256xf8E5M2, #dpas>) {
-    // CHECK-COUNT-128: llvm.store %{{.*}} : vector<1xi8>, !llvm.ptr<3>
+    // CHECK-COUNT-16: llvm.store %{{.*}} : vector<16xi8>, !llvm.ptr<3>
     // CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
-    // CHECK-COUNT-80: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<1xi8>
+    // CHECK-COUNT-2: llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xi8>
     %0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #dpas> -> tensor<128x256xf8E5M2, #blocked>
     tt.return
   }
@@ -902,7 +902,7 @@
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_blocked1d_to_slice1
   tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
-    // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<3>
+    // CHECK-COUNT-2: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xi32>
     %cvt = ttg.convert_layout %src : tensor<32xi32, #blocked0> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
     tt.return
   }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h
index c05bc31a9d..5a1607a51a 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.h
@@ -69,8 +69,6 @@ class TargetInfo : public mlir::triton::TargetInfoBase {
                 StringRef name, StringRef value,
                 unsigned addressSpace) const;
-  bool isXpu() const override { return true; }
-
 protected:
   virtual bool isSupportedWarpReduceOp(Operation *op,
                                        unsigned numLanesToReduce,
                                        unsigned warpSize) const = 0;