
Commit 0bc798e

tkarnagit-crd authored and committed
[MLIR][XeGPU][TransformOps] Add insert_prefetch op (llvm#167356)
Adds `transform.xegpu.insert_prefetch` transform op that inserts `xegpu.prefetch_nd` ops for the given `Value` in an `scf.for` loop.
1 parent cd9f75c commit 0bc798e

File tree

6 files changed: +408 -1 lines changed

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td

Lines changed: 44 additions & 0 deletions
@@ -200,4 +200,48 @@ def SetGPULaunchThreadsOp
   }];
 }
 
+def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
+      DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+      TransformOpInterface
+    ]> {
+
+  let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+  let description = [{
+    Given a target value (e.g., a `vector`) residing in an `scf.for` loop, this
+    transform finds the corresponding `xegpu.load_nd` op and inserts
+    `xegpu.prefetch_nd` operations for the tile. The load op must reside within
+    the `scf.for` loop. The number of prefetch steps is set by the
+    `nb_prefetch` argument (default value is 1). Returns a handle to the
+    created `xegpu.create_nd_tdesc` op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+                   Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
+                   DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
+  );
+
+  let results = (outs TransformHandleTypeInterface:$desc_op);
+
+  let assemblyFormat = [{
+    $target
+    `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    OpFoldResult getNbPrefetch() {
+      auto cxt = getContext();
+      if (getDynamicNbPrefetch())
+        return OpFoldResult(getDynamicNbPrefetch());
+      return OpFoldResult(IntegerAttr::get(
+          IntegerType::get(cxt, 64), getStaticNbPrefetch()));
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
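
For orientation, these are the two spellings the assembly format accepts, sketched against a hypothetical value handle %v (the full pipelines appear in the tests below):

// Static count: prints the static_nb_prefetch attribute inline.
%desc = transform.xegpu.insert_prefetch %v nb_prefetch = 2
    : (!transform.any_value) -> !transform.any_op

// Dynamic count: a transform param takes the dynamic_nb_prefetch slot.
%nb = transform.param.constant 2 : i64 -> !transform.param<i64>
%desc2 = transform.xegpu.insert_prefetch %v nb_prefetch = %nb
    : (!transform.any_value, !transform.param<i64>) -> !transform.any_op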

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp

Lines changed: 132 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"

@@ -405,6 +406,137 @@ void transform::SetGPULaunchThreadsOp::getEffects(
   modifiesPayload(effects);
 }
 
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues))
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  auto value = *targetValues.begin();
+
+  int64_t nbPrefetch = getStaticNbPrefetch();
+  if (getDynamicNbPrefetch()) {
+    // Get dynamic prefetch count from transform param or handle.
+    SmallVector<int32_t> dynamicNbPrefetch;
+    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+                                          {getDynamicNbPrefetch()});
+    if (!status.succeeded())
+      return status;
+    if (dynamicNbPrefetch.size() != 1)
+      return emitDefiniteFailure()
+             << "requires exactly one value for dynamic_nb_prefetch";
+    nbPrefetch = dynamicNbPrefetch[0];
+  }
+  if (nbPrefetch <= 0)
+    return emitSilenceableFailure(getLoc())
+           << "nb_prefetch must be a positive integer.";
+
+  // Find load operation of the operand.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp)
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  auto loadOp = *maybeLoadOp;
+  if (loadOp.getMixedOffsets().size() == 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op must have offsets.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the parent scf.for loop.
+  auto forOp = loadOp->getParentOfType<scf::ForOp>();
+  if (!forOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op is not contained in a scf.for loop.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find descriptor op.
+  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+  if (!maybeDescOp)
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  auto descOp = *maybeDescOp;
+  if (descOp.getMixedOffsets().size() > 0) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
470+
rewriter.setInsertionPoint(forOp);
471+
auto newDescOp =
472+
cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
473+
474+
// Clone reduction loop to emit initial prefetches.
475+
// Compute upper bound of the init loop: start + nbPrefetch * step.
476+
auto nbPrefetchCst =
477+
arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
478+
auto nbStep = rewriter.createOrFold<arith::MulIOp>(
479+
forOp.getLoc(), nbPrefetchCst, forOp.getStep());
480+
auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
481+
forOp.getLoc(), forOp.getLowerBound(), nbStep);
482+
auto initForOp =
483+
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
484+
initUpBound, forOp.getStep());
485+
486+
auto ctx = rewriter.getContext();
487+
auto readCacheHint =
488+
xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
489+
490+
// Modify loadOp mixedOffsets by replacing the for loop induction variable
491+
// with the given value.
492+
auto getPrefetchOffsets =
493+
[&](Value replacementVal) -> SmallVector<OpFoldResult> {
494+
IRMapping mapping;
495+
mapping.map(forOp.getInductionVar(), replacementVal);
496+
SmallVector<Value> dynamicOffsets =
497+
llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
498+
return mapping.lookupOrDefault(v);
499+
}));
500+
auto constOffsets = loadOp.getConstOffsets().value();
501+
return getMixedValues(constOffsets, dynamicOffsets, ctx);
502+
};
503+
504+
// Insert prefetch op in init loop.
505+
// Replace induction var with the init loop induction var.
506+
rewriter.setInsertionPointToStart(initForOp.getBody());
507+
xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
508+
newDescOp.getResult(),
509+
getPrefetchOffsets(initForOp.getInductionVar()),
510+
readCacheHint, readCacheHint, readCacheHint);
511+
512+
// Insert prefetch op in main loop.
513+
// Calculate prefetch offset after the init prefetches have been issued.
514+
rewriter.setInsertionPointToStart(forOp.getBody());
515+
auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
516+
forOp.getInductionVar(), nbStep);
517+
// Replace induction var with correct offset.
518+
xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
519+
newDescOp.getResult(),
520+
getPrefetchOffsets(prefetchOffset), readCacheHint,
521+
readCacheHint, readCacheHint);
522+
523+
// Unroll the init loop.
524+
if (failed(loopUnrollFull(initForOp)))
525+
return emitSilenceableFailure(getLoc()) << "Failed to unroll the loop";
526+
527+
results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
528+
529+
return DiagnosedSilenceableFailure::success();
530+
}
531+
532+
void transform::InsertPrefetchOp::getEffects(
533+
::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
534+
onlyReadsHandle(getTargetMutable(), effects);
535+
onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
536+
producesHandle(getOperation()->getOpResults(), effects);
537+
modifiesPayload(effects);
538+
}
539+
408540
namespace {
409541
class XeGPUTransformDialectExtension
410542
: public transform::TransformDialectExtension<
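
To make the rewrite concrete, here is a hand-written before/after sketch for nb_prefetch = 1 (illustrative names, cache-hint attributes elided; compare the CHECK lines in the tests below):

// Before: the tile load sits in the reduction loop with no prefetching.
%t = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
scf.for %iv = %c0 to %c4096 step %c32 iter_args(%acc = %init) -> (vector<256x256xf16>) {
  %a = xegpu.load_nd %t[%c0, %iv] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  // ... dpas and yield ...
}

// After: the descriptor is cloned above the loop; the init loop over
// [lb, lb + 1 * step) is fully unrolled into one leading prefetch, and the
// main loop prefetches nb_prefetch * step ahead of the current iteration.
// (The l1/l2/l3 "cached" hints on xegpu.prefetch_nd are elided here.)
%t2 = xegpu.create_nd_tdesc %A : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
xegpu.prefetch_nd %t2[%c0, %c0] : !xegpu.tensor_desc<256x32xf16>
scf.for %iv = %c0 to %c4096 step %c32 iter_args(%acc = %init) -> (vector<256x256xf16>) {
  %ahead = arith.addi %iv, %c32 : index
  xegpu.prefetch_nd %t2[%c0, %ahead] : !xegpu.tensor_desc<256x32xf16>
  %a = xegpu.load_nd %t[%c0, %iv] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  // ... dpas and yield ...
}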

mlir/python/mlir/dialects/transform/xegpu.py

Lines changed: 43 additions & 0 deletions
@@ -11,6 +11,7 @@
 from .._ods_common import _cext as _ods_cext
 from .._ods_common import (
     MixedValues,
+    MixedInt,
     get_op_result_or_value as _get_op_result_or_value,
     _dispatch_dynamic_index_list,
 )
@@ -134,6 +135,7 @@ def __init__(
         )
 
 
+@_ods_cext.register_operation(_Dialect, replace=True)
 class SetGPULaunchThreadsOp(SetGPULaunchThreadsOp):
     """Specialization for SetGPULaunchThreadsOp class."""
 
@@ -168,3 +170,44 @@ def set_gpu_launch_threads(
     ip=None,
 ) -> SetGPULaunchThreadsOp:
     return SetGPULaunchThreadsOp(launch_op, threads, loc=loc, ip=ip)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InsertPrefetchOp(InsertPrefetchOp):
+    """Specialization for InsertPrefetchOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        *,
+        nb_prefetch: Optional[MixedInt] = 1,
+        loc=None,
+        ip=None,
+    ):
+        static_nb_prefetch = 1
+        dynamic_nb_prefetch = None
+        if isinstance(nb_prefetch, int):
+            static_nb_prefetch = nb_prefetch
+        elif isinstance(nb_prefetch, IntegerAttr):
+            static_nb_prefetch = nb_prefetch.value  # pytype: disable=attribute-error
+        elif isinstance(nb_prefetch, (Operation, Value, OpView)):
+            dynamic_nb_prefetch = nb_prefetch
+
+        super().__init__(
+            transform.AnyOpType.get(),
+            target,
+            dynamic_nb_prefetch=dynamic_nb_prefetch,
+            static_nb_prefetch=static_nb_prefetch,
+            loc=loc,
+            ip=ip,
+        )
+
+
+def insert_prefetch(
+    target: Value,
+    *,
+    nb_prefetch: Optional[MixedInt] = 1,
+    loc=None,
+    ip=None,
+) -> OpResult:
+    return InsertPrefetchOp(target, nb_prefetch=nb_prefetch, loc=loc, ip=ip).result
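
For reference, a minimal usage sketch of the binding; `dpas_a` below is an assumed pre-existing `!transform.any_value` handle (e.g., produced by `transform.GetOperandOp` over a matched `xegpu.dpas` op), and `nb_param` an assumed transform param handle:

from mlir.dialects.transform import xegpu

# Static count: folded into the static_nb_prefetch attribute.
desc = xegpu.insert_prefetch(dpas_a, nb_prefetch=2)

# Dynamic count: an Operation/Value/OpView is routed to dynamic_nb_prefetch.
desc2 = xegpu.insert_prefetch(dpas_a, nb_prefetch=nb_param)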

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir

Lines changed: 31 additions & 0 deletions
@@ -124,3 +124,34 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  // expected-note@below {{load op}}
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+    // expected-error@below {{Load op is not contained in a scf.for loop.}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
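
The test above exercises the missing-loop diagnostic; the other silenceable precondition in apply, a load without offsets, would be tripped by IR of roughly this hypothetical shape:

// Hypothetical: an offset-less xegpu.load_nd; insert_prefetch would emit
// "Load op must have offsets." as a silenceable failure.
%d = xegpu.create_nd_tdesc %src : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
%v = xegpu.load_nd %d : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>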

mlir/test/Dialect/XeGPU/transform-ops.mlir

Lines changed: 92 additions & 0 deletions
@@ -308,3 +308,95 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  // CHECK: %[[C32:.+]] = arith.constant 32 : index
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[C0]]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C32]]
+    // CHECK: xegpu.prefetch_nd %[[V0]][%[[C0]], %[[ADD]]]
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  // CHECK: %[[C64:.+]] = arith.constant 64 : index
+  // CHECK: %[[C32:.+]] = arith.constant 32 : index
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C0]]]
+  // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[C32]]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for %[[ARG3:.+]] = %[[C0]]
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: %[[ADD:.+]] = arith.addi %[[ARG3]], %[[C64]]
+    // CHECK: xegpu.prefetch_nd %[[V0]][0, %[[ADD]]]
+    %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %func : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+    transform.apply_patterns to %func {
+      transform.apply_patterns.canonicalization
+    } : !transform.any_op
+    transform.yield
+  }
+}
