
Commit 71a38e8

[mlir][xegpu][transformops] add insert_prefetch op

1 parent: 2681497

6 files changed (+386, -1 lines)

mlir/include/mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.td
Lines changed: 43 additions & 0 deletions

@@ -161,4 +161,47 @@ def SetOpLayoutAttrOp : Op<Transform_Dialect, "xegpu.set_op_layout_attr", [
   }];
 }
 
+def InsertPrefetchOp : Op<Transform_Dialect, "xegpu.insert_prefetch", [
+  DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+  TransformOpInterface
+]> {
+
+  let summary = "Adds xegpu prefetch ops to matmul operand tiles.";
+  let description = [{
+    Given a target value (e.g., a `vector`) produced inside an `scf.for` loop,
+    this transform finds the corresponding `xegpu.load_nd` op and inserts
+    `xegpu.prefetch_nd` operations for the tile. The load op must reside within
+    the `scf.for` loop. The number of prefetch steps is set by the
+    `nb_prefetch` argument. Returns a handle to the created
+    `xegpu.create_nd_tdesc` op.
+  }];
+
+  let arguments = (ins TransformValueHandleTypeInterface:$target,
+      Optional<TransformAnyParamTypeOrAnyHandle>:$dynamic_nb_prefetch,
+      DefaultValuedOptionalAttr<I64Attr, "1">:$static_nb_prefetch
+  );
+
+  let results = (outs TransformHandleTypeInterface:$desc_op);
+
+  let assemblyFormat = [{
+    $target
+    `nb_prefetch` `=` ($dynamic_nb_prefetch^):($static_nb_prefetch)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    OpFoldResult getNbPrefetch() {
+      auto cxt = getContext();
+      if (getDynamicNbPrefetch())
+        return OpFoldResult(getDynamicNbPrefetch());
+      return OpFoldResult(IntegerAttr::get(
+          IntegerType::get(cxt, 64), getStaticNbPrefetch()));
+    }
+  }];
+}
+
 #endif // XEGPU_TRANSFORM_OPS
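
For illustration, the assembly format above admits both a static count and a count supplied dynamically through a transform param. A minimal sketch, mirroring the tests further down (value handle names are placeholders):

// Static count: stored in the `static_nb_prefetch` attribute.
%desc = transform.xegpu.insert_prefetch %operand nb_prefetch = 2
    : (!transform.any_value) -> !transform.any_op

// Dynamic count: consumed from a transform param handle.
%nb = transform.param.constant 2 : i64 -> !transform.param<i64>
%desc2 = transform.xegpu.insert_prefetch %operand2 nb_prefetch = %nb
    : (!transform.any_value, !transform.param<i64>) -> !transform.any_op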

mlir/lib/Dialect/XeGPU/TransformOps/XeGPUTransformOps.cpp
Lines changed: 138 additions & 0 deletions

@@ -8,6 +8,7 @@
 
 #include "mlir/Dialect/XeGPU/TransformOps/XeGPUTransformOps.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
 
@@ -341,6 +342,143 @@ void transform::SetOpLayoutAttrOp::getEffects(
   modifiesPayload(effects);
 }
 
+DiagnosedSilenceableFailure
+transform::InsertPrefetchOp::apply(transform::TransformRewriter &rewriter,
+                                   transform::TransformResults &results,
+                                   transform::TransformState &state) {
+  auto targetValues = state.getPayloadValues(getTarget());
+  if (!llvm::hasSingleElement(targetValues)) {
+    return emitDefiniteFailure()
+           << "requires exactly one target value handle (got "
+           << llvm::range_size(targetValues) << ")";
+  }
+  auto value = *targetValues.begin();
+
+  int64_t nbPrefetch = getStaticNbPrefetch();
+  if (getDynamicNbPrefetch()) {
+    // Get the dynamic prefetch count from the transform param or handle.
+    SmallVector<int32_t> dynamicNbPrefetch;
+    auto status = convertMixedValuesToInt(state, (*this), dynamicNbPrefetch,
+                                          {getDynamicNbPrefetch()});
+    if (!status.succeeded())
+      return status;
+    if (dynamicNbPrefetch.size() != 1) {
+      return emitDefiniteFailure()
+             << "requires exactly one value for dynamic_nb_prefetch";
+    }
+    nbPrefetch = dynamicNbPrefetch[0];
+  }
+  if (nbPrefetch <= 0) {
+    return emitSilenceableFailure(getLoc())
+           << "nb_prefetch must be a positive integer.";
+  }
+
+  // Find the load operation of the operand.
+  auto maybeLoadOp = findProducerOfType<xegpu::LoadNdOp>(value);
+  if (!maybeLoadOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find load op.";
+  }
+  auto loadOp = *maybeLoadOp;
+  if (loadOp.getMixedOffsets().empty()) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op must have offsets.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the parent scf.for loop.
+  auto forOp = loadOp->getParentOfType<scf::ForOp>();
+  if (!forOp) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "Load op is not contained in a scf.for loop.";
+    diag.attachNote(loadOp.getLoc()) << "load op";
+    return diag;
+  }
+
+  // Find the descriptor op.
+  auto maybeDescOp = findProducerOfType<xegpu::CreateNdDescOp>(value);
+  if (!maybeDescOp) {
+    return emitSilenceableFailure(getLoc()) << "Could not find descriptor op.";
+  }
+  auto descOp = *maybeDescOp;
+  if (!descOp.getMixedOffsets().empty()) {
+    auto diag = emitSilenceableFailure(getLoc())
+                << "desc op with offsets is not supported.";
+    diag.attachNote(descOp.getLoc()) << "desc op";
+    return diag;
+  }
+
+  // Clone the desc op outside the loop.
+  rewriter.setInsertionPoint(forOp);
+  auto newDescOp =
+      cast<xegpu::CreateNdDescOp>(rewriter.clone(*descOp.getOperation()));
+
+  // Create an init loop that emits the initial prefetches.
+  // Compute its upper bound: start + nbPrefetch * step.
+  auto nbPrefetchCst =
+      arith::ConstantIndexOp::create(rewriter, forOp.getLoc(), nbPrefetch);
+  auto nbStep = rewriter.createOrFold<arith::MulIOp>(
+      forOp.getLoc(), nbPrefetchCst, forOp.getStep());
+  auto initUpBound = rewriter.createOrFold<arith::AddIOp>(
+      forOp.getLoc(), forOp.getLowerBound(), nbStep);
+  auto initForOp =
+      scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+                         initUpBound, forOp.getStep());
+
+  auto ctx = rewriter.getContext();
+  auto readCacheHint =
+      xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
+
+  // Compute the loadOp's mixed offsets with the for-loop induction variable
+  // replaced by the given value.
+  auto getPrefetchOffsets =
+      [&](Value replacementVal) -> SmallVector<OpFoldResult> {
+    IRMapping mapping;
+    mapping.map(forOp.getInductionVar(), replacementVal);
+    SmallVector<Value> dynamicOffsets =
+        llvm::to_vector(llvm::map_range(loadOp.getOffsets(), [&](Value v) {
+          return mapping.lookupOrDefault(v);
+        }));
+    auto constOffsets = loadOp.getConstOffsets().value();
+    return getMixedValues(constOffsets, dynamicOffsets, ctx);
+  };
+
+  // Insert a prefetch op in the init loop, substituting the init loop's
+  // induction variable into the load offsets.
+  rewriter.setInsertionPointToStart(initForOp.getBody());
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(initForOp.getInductionVar()),
+                              readCacheHint, readCacheHint, readCacheHint);
+
+  // Insert a prefetch op in the main loop, offset by nbPrefetch * step so it
+  // stays ahead of the tiles already covered by the init prefetches.
+  rewriter.setInsertionPointToStart(forOp.getBody());
+  auto prefetchOffset = arith::AddIOp::create(rewriter, forOp.getLoc(),
+                                              forOp.getInductionVar(), nbStep);
+  xegpu::PrefetchNdOp::create(rewriter, newDescOp.getLoc(),
+                              newDescOp.getResult(),
+                              getPrefetchOffsets(prefetchOffset), readCacheHint,
+                              readCacheHint, readCacheHint);
+
+  // Fully unroll the init loop.
+  if (failed(loopUnrollFull(initForOp))) {
+    return emitSilenceableFailure(getLoc()) << "Failed to unroll the init loop";
+  }
+
+  results.set(llvm::cast<OpResult>(getResult()), {newDescOp});
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::InsertPrefetchOp::getEffects(
+    ::llvm::SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  onlyReadsHandle(getTargetMutable(), effects);
+  onlyReadsHandle(getDynamicNbPrefetchMutable(), effects);
+  producesHandle(getOperation()->getOpResults(), effects);
+  modifiesPayload(effects);
+}
+
 namespace {
 class XeGPUTransformDialectExtension
     : public transform::TransformDialectExtension<
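
Taken together, the rewrite hoists the tile descriptor out of the loop, issues nb_prefetch initial prefetches before the loop (the init loop is created and then fully unrolled), and issues one steady-state prefetch per iteration at offset iv + nb_prefetch * step, so data is always requested nb_prefetch iterations ahead of the load. A hand-written sketch of the intended result for nb_prefetch = 2 on the A operand of the matmul in the tests below (cache-hint attributes and the loop body are elided; this is not verbatim output):

%desc = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
// Unrolled init loop: prefetch the tiles for iterations 0 and 1.
xegpu.prefetch_nd %desc[%c0, %c0] : !xegpu.tensor_desc<256x32xf16>
xegpu.prefetch_nd %desc[%c0, %c32] : !xegpu.tensor_desc<256x32xf16>
scf.for %iv = %c0 to %c4096 step %c32 iter_args(%acc = %init) -> (vector<256x256xf16>) {
  // Steady state: prefetch the tile two iterations ahead.
  %ahead = arith.addi %iv, %c64 : index
  xegpu.prefetch_nd %desc[%c0, %ahead] : !xegpu.tensor_desc<256x32xf16>
  %a = xegpu.load_nd %3[%c0, %iv] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
  // ... dpas and yield as before ...
}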

mlir/python/mlir/dialects/transform/xegpu.py
Lines changed: 32 additions & 0 deletions

@@ -11,6 +11,7 @@
 from .._ods_common import _cext as _ods_cext
 from .._ods_common import (
     MixedValues,
+    MixedInt,
     get_op_result_or_value as _get_op_result_or_value,
     _dispatch_dynamic_index_list,
 )
@@ -132,3 +133,34 @@ def __init__(
             loc=loc,
             ip=ip,
         )
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class InsertPrefetchOp(InsertPrefetchOp):
+    """Specialization for InsertPrefetchOp class."""
+
+    def __init__(
+        self,
+        target: Value,
+        *,
+        nb_prefetch: Optional[MixedInt] = 1,
+        loc=None,
+        ip=None,
+    ):
+        static_nb_prefetch = 1
+        dynamic_nb_prefetch = None
+        if isinstance(nb_prefetch, int):
+            static_nb_prefetch = nb_prefetch
+        elif isinstance(nb_prefetch, IntegerAttr):
+            static_nb_prefetch = nb_prefetch.value  # pytype: disable=attribute-error
+        elif isinstance(nb_prefetch, (Operation, Value, OpView)):
+            dynamic_nb_prefetch = nb_prefetch
+
+        super().__init__(
+            transform.AnyOpType.get(),
+            target,
+            dynamic_nb_prefetch=dynamic_nb_prefetch,
+            static_nb_prefetch=static_nb_prefetch,
+            loc=loc,
+            ip=ip,
+        )
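
In the binding above, nb_prefetch is a MixedInt: a plain Python int or an IntegerAttr is folded into the static_nb_prefetch attribute, while an Operation, OpView, or Value (e.g., the result of a transform.param.constant) is forwarded as the dynamic_nb_prefetch operand. The result handle is typed !transform.any_op and points at the cloned xegpu.create_nd_tdesc op.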

mlir/test/Dialect/XeGPU/transform-ops-invalid.mlir
Lines changed: 31 additions & 0 deletions

@@ -71,3 +71,34 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_c
+func.func @insert_prefetch_dpas_c(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  // expected-note@below {{load op}}
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[2] : (!transform.any_op) -> !transform.any_value
+    // expected-error@below {{Load op is not contained in a scf.for loop.}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}

mlir/test/Dialect/XeGPU/transform-ops.mlir
Lines changed: 76 additions & 0 deletions

@@ -252,3 +252,79 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a
+func.func @insert_prefetch_dpas_a(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[%c0, %c0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[%c0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, %c0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = 1 : (!transform.any_value) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @insert_prefetch_dpas_a_nb_param2
+func.func @insert_prefetch_dpas_a_nb_param2(%arg0: memref<4096x4096xf16>, %arg1: memref<4096x4096xf16>, %arg2: memref<4096x4096xf16>) {
+  %c32 = arith.constant 32 : index
+  %c4096 = arith.constant 4096 : index
+  %c0 = arith.constant 0 : index
+  %0 = xegpu.create_nd_tdesc %arg2 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x256xf16>
+  %1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<256x256xf16> -> vector<256x256xf16>
+  // CHECK: xegpu.create_nd_tdesc %arg0
+  // CHECK: xegpu.create_nd_tdesc %arg1
+  // CHECK: %[[V0:.+]] = xegpu.create_nd_tdesc %arg0
+  // CHECK-SAME: !xegpu.tensor_desc<256x32xf16
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  // CHECK: xegpu.prefetch_nd %[[V0]]
+  %3 = xegpu.create_nd_tdesc %arg0 : memref<4096x4096xf16> -> !xegpu.tensor_desc<256x32xf16>
+  %4 = xegpu.create_nd_tdesc %arg1 : memref<4096x4096xf16> -> !xegpu.tensor_desc<32x256xf16>
+  // CHECK: scf.for
+  %2 = scf.for %arg3 = %c0 to %c4096 step %c32 iter_args(%arg4 = %1) -> (vector<256x256xf16>) {
+    // CHECK: xegpu.prefetch_nd %[[V0]]
+    %5 = xegpu.load_nd %3[0, %arg3] : !xegpu.tensor_desc<256x32xf16> -> vector<256x32xf16>
+    %6 = xegpu.load_nd %4[%arg3, 0] : !xegpu.tensor_desc<32x256xf16> -> vector<32x256xf16>
+    %7 = xegpu.dpas %5, %6, %arg4 : vector<256x32xf16>, vector<32x256xf16>, vector<256x256xf16> -> vector<256x256xf16>
+    scf.yield %7 : vector<256x256xf16>
+  }
+  return
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["xegpu.dpas"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_operand %0[0] : (!transform.any_op) -> !transform.any_value
+    %nb = transform.param.constant 2 : i64 -> !transform.param<i64>
+    // CHECK: transform.xegpu.insert_prefetch %{{.*}}
+    %2 = transform.xegpu.insert_prefetch %1 nb_prefetch = %nb : (!transform.any_value, !transform.param<i64>) -> !transform.any_op
+    transform.yield
+  }
+}
