Skip to content

Commit 14bfdba

Browse files
authored
Update create_nd_descriptor base address for 1d tile (#832)
commit 80a415a531800b1935b6b5e33b82d3fc5cb45b63 Author: Gune <[email protected]> Date: Thu Aug 1 15:17:36 2024 +0530 Update create_nd_descriptor base address for 1d tile Before this fix, the base address computation assumed that the tile is always 2d. For a 1d tile, the base address needs to be adjusted similarly.
1 parent 5a8f4db commit 14bfdba

File tree

6 files changed

+218
-24
lines changed

6 files changed

+218
-24
lines changed

lib/Conversion/XeGPUToVC/XeGPUToVC.cpp

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -97,43 +97,91 @@ static func::CallOp createFuncCall(PatternRewriter &rewriter, Location loc,
9797
return rewriter.create<func::CallOp>(loc, fn, resultType, operands);
9898
}
9999

100-
// Given an n-dim memref, a tensor descriptor defines a 2d memory region with
101-
// respect to the two inner-most dimensions. Other outer dimensions affect the
102-
// base address. For example, given
100+
// Given an n-dim memref, a tensor descriptor with tile rank of 2 defines a 2d
101+
// memory region with respect to the two inner-most dimensions. Other outer
102+
// dimensions affect the base address of the 2d plane.
103+
// For 2d, we compute the base address of 2d plane, assuming the coordinates
104+
// [0, 0] for the innermost 2 dimensions. The payload will record tile offset
105+
// within the 2d plane in a separate field.
106+
// For example, given
103107
// %m: memref<2x7x32x64xf16>
104108
// And this access
105109
// %m[%a, %b, %c, %d]
110+
//
106111
// The base address will be adjusted as follows:
107-
// new_base = base(%m) + %b * (32*64*2) + %a * (7*32*64*2)
112+
// base address of plane for 2d tile = base(%m) + %b * (32*64*2) + %a *
113+
// (7*32*64*2)
108114
// 2 is the number of bytes of the element type.
115+
//
116+
// For 1d, we compute the base address of the 1d tile, not the plane.
117+
// So the tile offset is also added to the base address.
118+
//
119+
// For tile rank of 1, the base address will be adjusted as:
120+
// base address of tile for 1d tile = base(%m) + %d * (2) + %c * (64*2) +
121+
// %b * (32*64*2) + %a * (7*32*64*2)
122+
109123
static Value adjustBasePointer(ConversionPatternRewriter &rewriter,
110-
xegpu::CreateNdDescOp op, Value base) {
124+
xegpu::CreateNdDescOp op, Value memrefBaseAddr) {
125+
auto memType = dyn_cast<MemRefType>(op.getSource().getType());
126+
127+
// FIXME: Only support static shape for now
128+
if (!memType || !memType.hasStaticShape())
129+
return memrefBaseAddr;
130+
111131
auto loc = op.getLoc();
112132

113133
auto createIndexConstant = [&](unsigned index) {
114134
return rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexType(),
115135
rewriter.getIndexAttr(index));
116136
};
117137

118-
if (auto memType = dyn_cast<MemRefType>(op.getSource().getType());
119-
memType && memType.getRank() > 2) {
120-
assert(memType.hasStaticShape() && "only support static shape for now");
121-
auto shape = memType.getShape();
122-
int64_t i = memType.getRank() - 1;
123-
unsigned stride = memType.getElementType().getIntOrFloatBitWidth() / 8;
124-
stride *= shape[i--];
125-
stride *= shape[i--];
126-
auto offsets = op.getMixedOffsets();
138+
auto tileRank = op.getTensorDesc().getType().getRank();
139+
auto offsets = op.getMixedOffsets();
140+
auto strides = mlir::getStridesAndOffset(memType).first;
141+
int64_t i = memType.getRank() - 1;
142+
143+
auto computeBase =
144+
[&](Value base) {
145+
for (; i >= 0; --i) {
146+
unsigned stride =
147+
strides[i] * memType.getElementType().getIntOrFloatBitWidth() / 8;
148+
auto factor = createIndexConstant(stride);
149+
auto offset = offsets.pop_back_val();
150+
Value offsetVal;
151+
152+
if (offset.is<Value>()) {
153+
offsetVal = offset.get<Value>();
154+
} else {
155+
offsetVal = createIndexConstant(
156+
llvm::cast<IntegerAttr>(offset.get<Attribute>()).getInt());
157+
}
158+
auto linearOffset =
159+
rewriter.create<arith::MulIOp>(loc, offsetVal, factor);
160+
base = rewriter.create<arith::AddIOp>(loc, base, linearOffset);
161+
}
162+
163+
return base;
164+
};
165+
166+
if (tileRank == 2 && memType.getRank() > 2) {
167+
// base address of plane for 2d: base addr of memref + offsets (starting
168+
// from j to i) for a given memref<ixjxkxlxf16>
169+
170+
i -= 2;
127171
offsets.pop_back_n(2);
128-
for (; i >= 0; --i) {
129-
auto factor = createIndexConstant(stride);
130-
auto linearOffset = rewriter.create<arith::MulIOp>(
131-
loc, offsets.pop_back_val().get<Value>(), factor);
132-
base = rewriter.create<arith::AddIOp>(loc, base, linearOffset);
133-
stride *= shape[i];
134-
}
172+
173+
auto baseOf2dPlane = computeBase(memrefBaseAddr);
174+
return baseOf2dPlane;
175+
}
176+
177+
if (tileRank == 1) {
178+
// base address of tile for 1d: base addr of memref + offsets (starting from
179+
// k to i) for a given memref<ixjxkxlxf16>
180+
auto baseOf1dTile = computeBase(memrefBaseAddr);
181+
return baseOf1dTile;
135182
}
136-
return base;
183+
184+
return memrefBaseAddr;
137185
}
138186

139187
struct CreateNdDescPattern

lib/ExecutionEngine/ImexRunnerUtils.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ _mlir_ciface_fillResource1DRandomF16(UnrankedMemRefType<f16> *ptr,
8888
_mlir_ciface_fillResource1DRandom(ptr, lower, upper, genInt);
8989
}
9090

91+
/// Fills a 1D memref of f32 type with uniformly distributed random values
92+
extern "C" void
93+
_mlir_ciface_fillResource1DRandomF32(UnrankedMemRefType<float> *ptr,
94+
const float lower, const float upper,
95+
const bool genInt) {
96+
_mlir_ciface_fillResource1DRandom(ptr, lower, upper, genInt);
97+
}
98+
9199
extern "C" void _mlir_ciface_printMemrefBF16(UnrankedMemRefType<bf16> *M) {
92100
_mlir_ciface_printMemref(M);
93101
}

test/Conversion/XeGPUToVC/create_nd_desc.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,5 +98,24 @@ module @gemm attributes {gpu.container_module} {
9898
//CHECK: gpu.return
9999
gpu.return
100100
}
101+
102+
// CHECK: gpu.func @test_create_nd_tdesc_4(%[[arg0:.*]]: memref<8x16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
103+
gpu.func @test_create_nd_tdesc_4(%arg0: memref<8x16xf16>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>}{
104+
//CHECK: %c1 = arith.constant 1 : index
105+
%c1 = arith.constant 1 : index
106+
107+
//CHECK: %intptr = memref.extract_aligned_pointer_as_index %arg0 : memref<8x16xf16> -> index
108+
//CHECK: %c2 = arith.constant 2 : index
109+
//CHECK: %0 = arith.muli %c1, %c2 : index
110+
//CHECK: %1 = arith.addi %intptr, %0 : index
111+
//CHECK: %c32 = arith.constant 32 : index
112+
//CHECK: %2 = arith.muli %c1, %c32 : index
113+
//CHECK: %3 = arith.addi %1, %2 : index
114+
//CHECK: %4 = arith.index_castui %3 : index to i64
115+
//CHECK: %5 = vector.insert %4, %cst [0] : i64 into vector<4xi64>
116+
%0 = xegpu.create_nd_tdesc %arg0[%c1, %c1] : memref<8x16xf16> -> !xegpu.tensor_desc<16xf16>
117+
//CHECK: gpu.return
118+
gpu.return
119+
}
101120
}
102121
}

test/Conversion/XeGPUToVC/xegpu-to-vc.mlir

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ module @gemm attributes {gpu.container_module} {
2222

2323
// CHECK: %[[A_STRUCT:.*]] = arith.constant dense<0> : vector<4xi64>
2424
// CHECK: %[[A_BASEPTR:.*]] = memref.extract_aligned_pointer_as_index {{.*}} : memref<128xf16> -> index
25-
// CHECK: %[[A_BASEADDR:.*]] = arith.index_castui %[[A_BASEPTR]] : index to i64
25+
// CHECK: %[[A_ELEMBYTES:.*]] = arith.constant 2 : index
26+
// CHECK: %[[A_OFFSET:.*]] = arith.constant 0 : index
27+
// CHECK: %[[A_STRIDE:.*]] = arith.muli %[[A_OFFSET]], %[[A_ELEMBYTES]] : index
28+
// CHECK: %[[A_UPDATEDBASEPTR:.*]] = arith.addi %[[A_BASEPTR]], %[[A_STRIDE]] : index
29+
// CHECK: %[[A_BASEADDR:.*]] = arith.index_castui %[[A_UPDATEDBASEPTR]] : index to i64
2630
// CHECK: %[[A_PAYLOAD_v4i64:.*]] = vector.insert %[[A_BASEADDR]], %[[A_STRUCT]] [0] : i64 into vector<4xi64>
2731
// CHECK: %[[A_PAYLOAD_v8i32:.*]] = vector.bitcast %[[A_PAYLOAD_v4i64]] : vector<4xi64> to vector<8xi32>
2832
%0 = xegpu.create_nd_tdesc %arg00[0] : memref<128xf16> -> !xegpu.tensor_desc<128xf16>
@@ -44,7 +48,11 @@ module @gemm attributes {gpu.container_module} {
4448

4549
// CHECK: %[[C_STRUCT:.*]] = arith.constant dense<0> : vector<4xi64>
4650
// CHECK: %[[C_BASEPTR:.*]] = memref.extract_aligned_pointer_as_index {{.*}} : memref<128xf32> -> index
47-
// CHECK: %[[C_BASE:.*]] = arith.index_castui %[[C_BASEPTR]] : index to i64
51+
// CHECK: %[[C_ELEMBYTES:.*]] = arith.constant 4 : index
52+
// CHECK: %[[C_OFFSET:.*]] = arith.constant 0 : index
53+
// CHECK: %[[C_STRIDE:.*]] = arith.muli %[[C_OFFSET]], %[[C_ELEMBYTES]] : index
54+
// CHECK: %[[C_UPDATEDBASEPTR:.*]] = arith.addi %[[C_BASEPTR]], %[[C_STRIDE]] : index
55+
// CHECK: %[[C_BASE:.*]] = arith.index_castui %[[C_UPDATEDBASEPTR]] : index to i64
4856
// CHECK: %[[C_PAYLOAD:.*]] = vector.insert %[[C_BASE]], %[[C_STRUCT]] [0] : i64 into vector<4xi64>
4957
// CHECK: %[[C_PAYLOAD_v8i32:.*]] = vector.bitcast %[[C_PAYLOAD]] : vector<4xi64> to vector<8xi32>
5058
%2 = xegpu.create_nd_tdesc %arg02[0] : memref<128xf32> -> !xegpu.tensor_desc<128xf32>
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// RUN: %python_executable %imex_runner --requires=l0-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc-rawsend-false.pp \
2+
// RUN: --runner imex-cpu-runner -e main \
3+
// RUN: --entry-point-result=void \
4+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%levelzero_runtime --filecheck
5+
// RUN: %python_executable %imex_runner --requires=sycl-runtime -i %s --pass-pipeline-file=%p/xegpu-to-func-vc-rawsend-false.pp \
6+
// RUN: --runner imex-cpu-runner -e main \
7+
// RUN: --entry-point-result=void \
8+
// RUN: --shared-libs=%irunner_utils,%mlir_runner_utils,%mlir_c_runner_utils,%sycl_runtime --filecheck
9+
module @gemm attributes {gpu.container_module} {
10+
memref.global "private" constant @__constant_8x16xf32 : memref<8x16xf32> = dense<0.0>
11+
func.func @test(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>) -> memref<8x16xf32> attributes {llvm.emit_c_interface} {
12+
%c1 = arith.constant 1 : index
13+
%c8 = arith.constant 8 : index
14+
15+
%memref = gpu.alloc host_shared () : memref<8x16xf32>
16+
memref.copy %arg0, %memref : memref<8x16xf32> to memref<8x16xf32>
17+
%memref_1 = gpu.alloc host_shared () : memref<8x16xf32>
18+
memref.copy %arg1, %memref_1 : memref<8x16xf32> to memref<8x16xf32>
19+
%memref_2 = gpu.alloc host_shared () : memref<8x16xf32>
20+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c1, %c1, %c1) threads in (%c8, %c1, %c1) args(%memref : memref<8x16xf32>, %memref_1 : memref<8x16xf32>, %memref_2 : memref<8x16xf32>)
21+
gpu.dealloc %memref : memref<8x16xf32>
22+
gpu.dealloc %memref_1 : memref<8x16xf32>
23+
return %memref_2 : memref<8x16xf32>
24+
}
25+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute]>, api=OpenCL, #spirv.resource_limits<>>} {
26+
gpu.func @test_kernel(%arg0: memref<8x16xf32>, %arg1: memref<8x16xf32>, %arg2: memref<8x16xf32>) kernel attributes {VectorComputeFunctionINTEL, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
27+
%thread_id_x = gpu.thread_id x
28+
cf.br ^bb1
29+
^bb1:
30+
%0 = xegpu.create_nd_tdesc %arg1[%thread_id_x, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32>
31+
%1 = xegpu.load_nd %0 : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
32+
%2 = xegpu.create_nd_tdesc %arg0[%thread_id_x, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32>
33+
%3 = xegpu.load_nd %2 : !xegpu.tensor_desc<16xf32> -> vector<16xf32>
34+
%4 = arith.addf %3, %1 : vector<16xf32>
35+
%5 = xegpu.create_nd_tdesc %arg2[%thread_id_x, 0] : memref<8x16xf32> -> !xegpu.tensor_desc<16xf32>
36+
xegpu.store_nd %4, %5 : vector<16xf32>, !xegpu.tensor_desc<16xf32>
37+
gpu.return
38+
}
39+
}
40+
func.func @main() attributes {llvm.emit_c_interface} {
41+
%c_gen_int = arith.constant 0 : i1
42+
%cf_lower = arith.constant -0.5 : f32
43+
%cf_upper = arith.constant 0.5 : f32
44+
45+
%A = memref.alloc() : memref<8x16xf32>
46+
%A_random = memref.cast %A : memref<8x16xf32> to memref<*xf32>
47+
call @fillResource1DRandomF32(%A_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> ()
48+
49+
%B = memref.alloc() : memref<8x16xf32>
50+
%B_random = memref.cast %B : memref<8x16xf32> to memref<*xf32>
51+
call @fillResource1DRandomF32(%B_random, %cf_lower, %cf_upper, %c_gen_int) : (memref<*xf32>, f32, f32, i1) -> ()
52+
53+
// compute the reference result: element-wise sum of A and B
54+
%c16 = arith.constant 16 : index
55+
%c8 = arith.constant 8 : index
56+
%c1 = arith.constant 1 : index
57+
%c0 = arith.constant 0 : index
58+
%ref = memref.alloc() : memref<8x16xf32>
59+
scf.for %i = %c0 to %c8 step %c1 {
60+
scf.for %j = %c0 to %c16 step %c1 {
61+
%a = memref.load %A[%i, %j] : memref<8x16xf32>
62+
%b = memref.load %B[%i, %j] : memref<8x16xf32>
63+
%c = arith.addf %a, %b : f32
64+
memref.store %c, %ref[%i, %j] : memref<8x16xf32>
65+
}
66+
}
67+
68+
%C = call @test(%A, %B) : (memref<8x16xf32>, memref<8x16xf32>) -> memref<8x16xf32>
69+
70+
%C_cast = memref.cast %C : memref<8x16xf32> to memref<*xf32>
71+
%ref_cast = memref.cast %ref : memref<8x16xf32> to memref<*xf32>
72+
// call @printMemrefF32(%C_cast) : (memref<*xf32>) -> ()
73+
// CHECK: [ALLCLOSE: TRUE]
74+
call @printAllcloseF32(%ref_cast, %C_cast) : (memref<*xf32>, memref<*xf32>) -> ()
75+
return
76+
}
77+
func.func private @printMemrefF32(memref<*xf32>) attributes {llvm.emit_c_interface}
78+
func.func private @fillResource1DRandomF32(memref<*xf32>, f32, f32, i1) attributes {llvm.emit_c_interface}
79+
func.func private @printAllcloseF32(memref<*xf32>, memref<*xf32>) attributes {llvm.emit_c_interface}
80+
}
81+
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// gpu dialect with intel intrinsic functions (func dialect) to
2+
// llvm dialect (for host code) and
3+
// spirv dialect (for device code) lowering pipeline.
4+
// Ready for imex runner starting from GPU dialect.
5+
builtin.module(
6+
imex-vector-linearize
7+
gpu.module(convert-xegpu-to-vc{useRawSend=false})
8+
reconcile-unrealized-casts
9+
bf16-to-gpu
10+
imex-convert-gpu-to-spirv
11+
spirv.module(spirv-lower-abi-attrs
12+
spirv-update-vce)
13+
func.func(llvm-request-c-wrappers)
14+
serialize-spirv
15+
convert-vector-to-scf
16+
convert-gpu-to-gpux
17+
convert-scf-to-cf
18+
convert-cf-to-llvm
19+
convert-vector-to-llvm
20+
convert-index-to-llvm
21+
convert-arith-to-llvm
22+
convert-func-to-llvm
23+
convert-math-to-llvm
24+
convert-gpux-to-llvm
25+
convert-index-to-llvm
26+
expand-strided-metadata
27+
lower-affine
28+
finalize-memref-to-llvm
29+
reconcile-unrealized-casts)
30+
// End

0 commit comments

Comments
 (0)