Skip to content

Commit 8a0e145

Browse files
committed
add unrolling support for scatter operations
1 parent 30b099e commit 8a0e145

File tree

2 files changed

+296
-6
lines changed

2 files changed

+296
-6
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 153 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -409,19 +409,14 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
409409

410410
auto newTdescTy = getUnrolledTypes(tdescTy, *targetShape)[0];
411411

412-
413412
TypedValue<::mlir::VectorType> indiceVec = op.getOffsets();
414-
415413
VectorType indiceVecTy = indiceVec.getType();
416-
417414
SmallVector<Type> convertedIndiceTypes =
418415
getUnrolledTypes(indiceVecTy, *targetShape);
419-
420416
SmallVector<Value> convertedIndiceVec =
421417
pack(indiceVec, convertedIndiceTypes, *targetShape, loc, rewriter);
422418

423419
SmallVector<Value> newOps;
424-
425420
for (auto indice : convertedIndiceVec) {
426421
auto newOp = rewriter.create<xegpu::CreateDescOp>(loc, newTdescTy, op.getSource(), indice);
427422
newOps.push_back(newOp);
@@ -434,12 +429,164 @@ struct UnrollCreateDescOp : public UnrollPattern<xegpu::CreateDescOp> {
434429
}
435430
};
436431

432+
/// Unrolls a scattered xegpu::LoadGatherOp whose tensor descriptor carries an
/// `inst_data` layout into multiple tile-sized loads. The tensor descriptor
/// and the mask are split into tiles of `targetShape`, one load is emitted per
/// tile, and the partial results are recombined (via unpack) into the original
/// result vector.
struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
  using UnrollPattern<xegpu::LoadGatherOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::LoadGatherOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();
    VectorType maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    // Guard the dyn_casts: the pattern only applies when both the loaded
    // value and the mask are vectors; otherwise cloneWith/getUnrolledTypes
    // below would dereference a null type.
    if (!valueTy || !maskTy)
      return failure();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Result tiles keep the descriptor's element type but take the target
    // (instruction) shape.
    Type elemTy = tdescTy.getElementType();
    VectorType newValueTy = valueTy.cloneWith(*targetShape, elemTy);

    SmallVector<Type> convertedTdescTypes =
        getUnrolledTypes(tdescTy, *targetShape);
    SmallVector<Value> convertedTdescs = pack(
        op.getTensorDesc(), convertedTdescTypes, *targetShape, loc, rewriter);

    SmallVector<Type> convertedMaskTypes =
        getUnrolledTypes(maskTy, *targetShape);
    SmallVector<Value> convertedMasks =
        pack(op.getMask(), convertedMaskTypes, *targetShape, loc, rewriter);

    // One tile-sized load per (descriptor, mask) pair; cache hints and the
    // transpose attribute carry over unchanged.
    SmallVector<Value> newOps;
    for (auto [t, m] : llvm::zip(convertedTdescs, convertedMasks)) {
      auto newOp = rewriter.create<xegpu::LoadGatherOp>(
          loc, newValueTy, t, m, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
      newOps.push_back(newOp);
    }

    // Stitch the tile results back into the original vector shape.
    Value castOp = unpack(newOps, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, castOp);
    return success();
  }
};
474+
475+
/// Unrolls a scattered xegpu::PrefetchOp into one prefetch per tile of the
/// target (instruction) shape. Prefetch has no results, so the original op is
/// simply erased after the tile-sized prefetches are emitted.
struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
  using UnrollPattern<xegpu::PrefetchOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::PrefetchOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType descTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Split the tensor descriptor into tile-sized sub-descriptors.
    SmallVector<Type> subDescTypes = getUnrolledTypes(descTy, *targetShape);
    SmallVector<Value> subDescs = pack(op.getTensorDesc(), subDescTypes,
                                       *targetShape, loc, rewriter);

    // Re-emit one prefetch per sub-descriptor, forwarding all original
    // attributes (cache hints, etc.).
    for (Value desc : subDescs)
      rewriter.create<xegpu::PrefetchOp>(loc, TypeRange(), desc,
                                         op->getAttrs());

    rewriter.eraseOp(op);
    return success();
  }
};
498+
499+
/// Unrolls a scattered xegpu::StoreScatterOp into tile-sized stores. The
/// stored value, the tensor descriptor, and (when present) the mask are each
/// split into tiles of the target shape, and one store is emitted per tile.
struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
  using UnrollPattern<xegpu::StoreScatterOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::StoreScatterOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
    xegpu::TensorDescType tdescTy = op.getTensorDescType();

    // The mask operand is optional on StoreScatterOp.
    VectorType maskTy;
    if (op.getMask())
      maskTy = llvm::dyn_cast<VectorType>(op.getMask().getType());

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    SmallVector<Type> splitValueTypes =
        getUnrolledTypes(valueTy, *targetShape);
    SmallVector<Type> splitDescTypes = getUnrolledTypes(tdescTy, *targetShape);

    SmallVector<Value> splitValues =
        pack(op.getValue(), splitValueTypes, *targetShape, loc, rewriter);
    SmallVector<Value> splitDescs =
        pack(op.getTensorDesc(), splitDescTypes, *targetShape, loc, rewriter);

    SmallVector<Value> splitMasks;
    if (op.getMask()) {
      SmallVector<Type> splitMaskTypes =
          getUnrolledTypes(maskTy, *targetShape);
      splitMasks =
          pack(op.getMask(), splitMaskTypes, *targetShape, loc, rewriter);
    }

    // Emit one store per tile, reusing the transpose attribute and the
    // original cache hints. A null mask Value is passed when the original op
    // had no mask.
    for (auto [idx, tile] :
         llvm::enumerate(llvm::zip(splitValues, splitDescs))) {
      auto [value, desc] = tile;
      Value mask = op.getMask() ? splitMasks[idx] : Value();
      rewriter.create<xegpu::StoreScatterOp>(
          loc, value, desc, mask, op.getTransposeAttr(), op.getL1HintAttr(),
          op.getL2HintAttr(), op.getL3HintAttr());
    }

    rewriter.eraseOp(op);
    return success();
  }
};
547+
548+
/// Unrolls xegpu::UpdateOffsetOp into per-tile offset updates: both the
/// tensor descriptor and the offset vector are split into tiles of the target
/// shape, updated pairwise, and the results recombined into the original
/// descriptor type.
struct UnrollUpdateOffsetOp : public UnrollPattern<xegpu::UpdateOffsetOp> {
  using UnrollPattern<xegpu::UpdateOffsetOp>::UnrollPattern;
  LogicalResult matchAndRewrite(xegpu::UpdateOffsetOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    xegpu::TensorDescType descTy = op.getTensorDescType();

    std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
    if (!targetShape)
      return failure();

    // Tile-sized sub-descriptors.
    SmallVector<Type> subDescTypes = getUnrolledTypes(descTy, *targetShape);
    SmallVector<Value> subDescs = pack(op.getTensorDesc(), subDescTypes,
                                       *targetShape, loc, rewriter);

    // Matching tile-sized offset vectors.
    TypedValue<::mlir::VectorType> offsets = op.getOffsets();
    SmallVector<Type> subOffsetTypes =
        getUnrolledTypes(offsets.getType(), *targetShape);
    SmallVector<Value> subOffsets =
        pack(offsets, subOffsetTypes, *targetShape, loc, rewriter);

    // Pairwise update; each new op keeps its sub-descriptor's type.
    SmallVector<Value> updated;
    for (auto [desc, off] : llvm::zip(subDescs, subOffsets))
      updated.push_back(
          rewriter.create<xegpu::UpdateOffsetOp>(loc, desc.getType(), desc,
                                                 off));

    Value result = unpack(updated, op.getType(), *targetShape, loc, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
};
582+
437583
} // namespace
438584

439585
/// Registers all XeGPU unrolling patterns: the nd-descriptor patterns
/// (create/update/prefetch/load/store), dpas, and the scatter-path patterns
/// (create_tdesc, gather load, scatter store, prefetch, update_offset).
/// `options` controls the target instruction shapes used by the patterns.
void mlir::xegpu::populateXeGPUUnrollPatterns(
    RewritePatternSet &patterns, const xegpu::UnrollOptions &options) {
  patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
               UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp,
               UnrollCreateDescOp, UnrollLoadGatherOp,
               UnrollStoreScatterOp, UnrollPrefetchOp, UnrollUpdateOffsetOp>(
      patterns.getContext(), options);
}

mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,147 @@ gpu.module @test {
158158
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
159159
gpu.return %c : vector<32x32xf32>
160160
}
161+
162+
//-----
163+
164+
// Checks that a 32-lane scattered tdesc with inst_data = [16] is unrolled
// into two 16-lane create_tdesc ops.
// CHECK-LABEL: test_create_tdesc_vec
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_create_tdesc_vec(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
  // Offsets supplied as a literal dense vector (stride-8 byte offsets).
  %cst = arith.constant dense<[
    0, 8, 16, 24, 32, 40, 48, 56,
    64, 72, 80, 88, 96, 104, 112, 120,
    128, 136, 144, 152, 160, 168, 176, 184,
    192, 200, 208, 216, 224, 232, 240, 248
  ]> : vector<32xindex>
  %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
  gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
}
177+
178+
//-----
179+
180+
// Same unrolling as test_create_tdesc_vec, but with offsets computed from
// vector.step * 8 instead of a literal constant.
// CHECK-LABEL: test_create_tdesc_step
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_create_tdesc_step(%src: ui64) -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>> {
  %step = arith.constant dense<8> : vector<32xindex>
  %seq = vector.step : vector<32xindex>
  %cst = arith.muli %seq, %step : vector<32xindex>
  %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
  gpu.return %tdesc : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
}
190+
191+
//-----
192+
193+
// Checks that a masked 32-lane gather load is unrolled into two 16-lane
// create_tdesc + load pairs.
// CHECK-LABEL: test_load
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
gpu.func @test_load(%src: ui64) -> vector<32xf32> {
  %cst = arith.constant dense<[
    0, 8, 16, 24, 32, 40, 48, 56,
    64, 72, 80, 88, 96, 104, 112, 120,
    128, 136, 144, 152, 160, 168, 176, 184,
    192, 200, 208, 216, 224, 232, 240, 248
  ]> : vector<32xindex>

  // Partial mask (first 17 lanes active) to exercise mask splitting.
  %c17 = arith.constant 17: index
  %mask = vector.create_mask %c17: vector<32xi1>

  %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
  %ld = xegpu.load %tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>

  gpu.return %ld : vector<32xf32>
}
213+
214+
//-----
215+
216+
// Checks that a scattered prefetch over a 32-lane tdesc is unrolled into two
// 16-lane prefetches.
// CHECK-LABEL: test_prefetch
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
gpu.func @test_prefetch(%src: ui64) {

  %cst = arith.constant dense<[
    0, 8, 16, 24, 32, 40, 48, 56,
    64, 72, 80, 88, 96, 104, 112, 120,
    128, 136, 144, 152, 160, 168, 176, 184,
    192, 200, 208, 216, 224, 232, 240, 248
  ]> : vector<32xindex>

  %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

  xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
  gpu.return
}
234+
235+
//-----
236+
237+
// Checks that a masked 32-lane scattered store is unrolled into two 16-lane
// create_tdesc + store pairs.
// CHECK-LABEL: test_store
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>
gpu.func @test_store(%src: ui64) {
  %cst = arith.constant dense<[
    0, 8, 16, 24, 32, 40, 48, 56,
    64, 72, 80, 88, 96, 104, 112, 120,
    128, 136, 144, 152, 160, 168, 176, 184,
    192, 200, 208, 216, 224, 232, 240, 248
  ]> : vector<32xindex>

  // Partial mask (first 17 lanes active) to exercise mask splitting.
  %c17 = arith.constant 17: index
  %mask = vector.create_mask %c17: vector<32xi1>

  %st_vec = arith.constant dense<1023.>: vector<32xf32>
  %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>
  xegpu.store %st_vec, %tdesc, %mask: vector<32xf32>, !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1>

  gpu.return
}
258+
259+
//-----
260+
261+
// End-to-end check of the whole scatter pipeline: prefetch, update_offset,
// load, and store over one 32-lane tdesc must each unroll into two 16-lane
// ops.
// CHECK-LABEL: test_prefetch_load_store_update
// CHECK-SAME: [[arg0:%.+]]: ui64
// CHECK-COUNT-2: xegpu.create_tdesc [[arg0]], {{.*}} : ui64, vector<16xindex> -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.prefetch {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>
// CHECK-COUNT-2: xegpu.update_offset {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xindex>
// CHECK-COUNT-2: xegpu.load {{.*}} : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
// CHECK-COUNT-2: xegpu.store {{.*}} : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1>

gpu.func @test_prefetch_load_store_update(%src: ui64) {

  %cst = arith.constant dense<[
    0, 8, 16, 24, 32, 40, 48, 56,
    64, 72, 80, 88, 96, 104, 112, 120,
    128, 136, 144, 152, 160, 168, 176, 184,
    192, 200, 208, 216, 224, 232, 240, 248
  ]> : vector<32xindex>

  %tdesc = xegpu.create_tdesc %src, %cst : ui64, vector<32xindex> -> !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

  xegpu.prefetch %tdesc: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>

  // Non-uniform deltas so the two unrolled update_offset halves differ.
  %delta = arith.constant dense<[
    32, 32, 32, 32, 32, 32, 32, 32,
    32, 32, 32, 32, 32, 32, 32, 64,
    128, 128, 128, 128, 128, 128, 128, 128,
    128, 128, 128, 128, 128, 128, 128, 256
  ]> : vector<32xindex>
  %new_tdesc = xegpu.update_offset %tdesc, %delta
    : !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xindex>

  %c17 = arith.constant 17: index
  %mask = vector.create_mask %c17: vector<32xi1>

  // Load through the updated descriptor ...
  %ld_vec = xegpu.load %new_tdesc, %mask: !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>, vector<32xi1> -> vector<32xf32>

  %st_vec = arith.addf %ld_vec, %ld_vec : vector<32xf32>
  // ... but store through the original one (both paths get unrolled).
  xegpu.store %st_vec, %tdesc, %mask:
    vector<32xf32>,
    !xegpu.tensor_desc<32xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<inst_data = [16]>>,
    vector<32xi1>

  gpu.return
}
161304
}

0 commit comments

Comments
 (0)