Skip to content

Commit de0a1bb

Browse files
committed
add unit test
1 parent 62aa1dd commit de0a1bb

File tree

5 files changed

+203
-2
lines changed

5 files changed

+203
-2
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,10 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
410410

411411
let extraClassDeclaration = [{
412412

413+
int64_t getRank() const {
414+
return getParent().getRank() - getDims().size();
415+
}
416+
413417
DenseI32ArrayAttr getOrder() const {
414418
return getParent().getOrder();
415419
}

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
296296
// XeGPU_SliceAttr
297297
//===----------------------------------------------------------------------===//
298298
LogicalResult
299-
SliceAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
299+
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
300300
xegpu::LayoutAttr parent, DenseI64ArrayAttr dims) {
301301
if (!parent || !dims)
302302
return emitError() << "expected parent layout and dims attribute";
@@ -322,7 +322,68 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
322322
/// Generates the IR computing, for the subgroup identified by `linearId`, the
/// offsets of every tile it accesses inside data of the given `shape`.
///
/// \param builder  builder used to emit the index computations.
/// \param loc      location attached to every emitted op.
/// \param linearId linearized subgroup id.
/// \param shape    shape being distributed; its rank must equal getRank().
/// \returns one offset vector per tile enumerated by
///          StaticTileOffsetRange(shape, distUnit), or failure if this is not
///          a workgroup-level layout, if neither the effective sg_data nor
///          computeShapeRatio(shape, sgLayout) yields a subgroup shape, or if
///          the subgroup id cannot be delinearized.
FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
                      ArrayRef<int64_t> shape) {
  assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
  if (!isWgLayout())
    return failure();

  auto sgLayout = getEffectiveSgLayout().value();

  // Prefer the explicit subgroup data; otherwise derive the per-subgroup
  // shape from the shape / sg_layout ratio.
  SmallVector<int64_t> sgShape;
  if (auto maybeSgShape = getEffectiveSgData())
    sgShape = maybeSgShape.value();
  else if (auto ratio = computeShapeRatio(shape, sgLayout))
    sgShape = ratio.value();
  else
    return failure();

  // distUnit[i] is the minimum value between shape[i] and
  // sgLayout[i] * sgShape[i]
  SmallVector<int64_t> distUnit = llvm::map_to_vector(
      llvm::zip_equal(shape, computeElementwiseMul(sgLayout, sgShape)),
      [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

  // delinearize Ids
  auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
  if (failed(maybeIds))
    return failure();

  // The effective sgIds for offsets computing correspond
  // to the dims that are not sliced.
  ArrayRef<int64_t> dims = getDims().asArrayRef();
  SmallVector<Value> sgIds =
      XeGPUDialect::dropDims(ArrayRef<Value>(*maybeIds), dims);

  // nd local offset, localOffset[i] = sgId[i] * sgShape[i]
  // zip_equal (rather than zip) so that a rank mismatch between the
  // remaining subgroup ids and the subgroup shape asserts instead of being
  // silently truncated, matching the other pairings in this function.
  SmallVector<Value> localOffsets = llvm::map_to_vector(
      llvm::zip_equal(sgIds, sgShape), [&](const auto &t) -> Value {
        return builder.createOrFold<index::MulOp>(
            loc, std::get<0>(t),
            builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
      });

  // For each distribution unit: offset = (unitBase + localOffset) % distUnit.
  // NOTE: constants are intentionally materialized right before their users
  // inside the loop; the order in which ops are emitted is observable by
  // FileCheck-based tests, so do not hoist them.
  SmallVector<SmallVector<Value>> offsets;
  for (SmallVector<int64_t> unitOffs : StaticTileOffsetRange(shape, distUnit)) {
    SmallVector<Value> base =
        llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
          return builder.create<arith::ConstantIndexOp>(loc, d);
        });

    SmallVector<Value> adds = llvm::map_to_vector(
        llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
          return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
                                                     std::get<1>(t));
        });

    SmallVector<Value> mods = llvm::map_to_vector(
        llvm::zip_equal(adds, distUnit), [&](const auto &t) -> Value {
          return builder.createOrFold<index::RemUOp>(
              loc, std::get<0>(t),
              builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
        });

    offsets.push_back(mods);
  }

  return offsets;
}
327388

328389
//===----------------------------------------------------------------------===//

mlir/test/Dialect/XeGPU/layout.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,10 @@ gpu.func @convert_layout_wg(%a: vector<32x64xf16>) {
5050
gpu.return
5151
}
5252

53+
gpu.func @slice_attr_repeat_dim() {
54+
//CHECK: arith.constant {layout_result_0 = #xegpu.slice<<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
55+
%cst = arith.constant {layout_result_0 = #xegpu.slice<<sg_layout = [16, 1, 1], sg_data = [1, 8, 2]>, dims = [2]>} dense<8> : vector<16x8xindex>
56+
gpu.return
57+
}
58+
5359
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt --test-xegpu-layout-interface --cse -split-input-file %s | FileCheck %s
2+
3+
#block = #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32]>
4+
#slice = #xegpu.slice<#block, dims=[1]>
5+
6+
//CHECK: #map = affine_map<()[s0] -> (s0 floordiv 8)>
7+
gpu.module @test_1_1_assignment {
8+
gpu.func @create_nd_tdesc() -> vector<128xindex> {
9+
//CHECK: [[sgId:%.+]] = gpu.subgroup_id : index
10+
//CHECK: [[IDY:%.+]] = affine.apply #map()[[[sgId]]]
11+
//CHECK: [[c32:%.+]] = arith.constant 32 : index
12+
//CHECK: [[LOCALY:%.+]] = index.mul [[IDY]], [[c32]]
13+
//CHECK: [[c0:%.+]] = arith.constant 0 : index
14+
//CHECK: [[Y:%.+]] = arith.addi [[LOCALY]], [[c0]] : index
15+
//CHECK: [[c128:%.+]] = arith.constant 128 : index
16+
//CHECK: [[MODY:%.+]] = index.remu [[Y]], [[c128]]
17+
//CHECK: [[BASE:%.+]] = vector.step : vector<32xindex>
18+
//CHECK: [[CAST:%.+]] = vector.broadcast [[MODY]] : index to vector<32xindex>
19+
//CHECK: [[ADD:%.+]] = arith.addi [[BASE]], [[CAST]] : vector<32xindex>
20+
%step = vector.step {layout_result_0 = #slice}: vector<128xindex>
21+
gpu.return %step : vector<128xindex>
22+
}
23+
}

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
10+
#include "mlir/Dialect/Index/IR/IndexDialect.h"
1011
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1112
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1213
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
14+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1315
#include "mlir/Pass/Pass.h"
1416
#include "mlir/Pass/PassManager.h"
17+
#include "mlir/Transforms/DialectConversion.h"
1518
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1619

1720
using namespace mlir;
@@ -149,12 +152,116 @@ struct TestXeGPUUnrollingPatterns
149152
}
150153
};
151154

155+
#undef DEBUG_TYPE
156+
#define DEBUG_TYPE "test-xegpu-layout-interface"
157+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
158+
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
159+
160+
/// Test-only conversion pattern exercising SliceAttr::getOffsets.
///
/// Matches a `vector.step` op whose result layout attribute is a rank-1
/// xegpu::SliceAttr with effective subgroup data. It emits a
/// `gpu.subgroup_id`, asks the slice attribute for the offsets of that
/// subgroup, creates a `vector.step` of the subgroup shape, and for every
/// returned offset vector adds a broadcast of its first (only) offset to that
/// base step. The original op is replaced by all of the resulting values.
class TestStepOpPattern : public OpConversionPattern<vector::StepOp> {
  using OpConversionPattern<vector::StepOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::StepOp op, OneToNOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {

    // Only ops annotated with a rank-1 slice layout are handled.
    auto layoutName = xegpu::getLayoutName(op->getResult(0));
    auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
    if (!sliceAttr || sliceAttr.getRank() != 1)
      return failure();

    // The subgroup shape determines the type of the distributed step ops.
    std::optional<SmallVector<int64_t>> sgShape =
        sliceAttr.getEffectiveSgData();
    if (!sgShape)
      return failure();

    Location loc = op.getLoc();
    VectorType type = op.getResult().getType();
    auto wgShape = type.getShape();

    Value sgId =
        gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
    auto maybeOffsets = sliceAttr.getOffsets(rewriter, loc, sgId, wgShape);
    if (failed(maybeOffsets))
      return failure();

    VectorType newTy = type.cloneWith(*sgShape, type.getElementType());
    Value base = vector::StepOp::create(rewriter, loc, newTy);
    SmallVector<Value> newOps;
    // Bind by const reference: each element is a SmallVector<Value> and
    // iterating by value would copy it on every iteration.
    for (const SmallVector<Value> &offsets : *maybeOffsets) {
      // The layout is rank-1 (checked above), so offsets[0] is the only
      // offset of this tile.
      Value bcast =
          vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
      Value add = arith::AddIOp::create(rewriter, loc, base, bcast);
      newOps.push_back(add);
    }
    rewriter.replaceOpWithMultiple(op, {newOps});
    return success();
  }
};
200+
201+
struct TestXeGPULayoutInterface
202+
: public PassWrapper<TestXeGPULayoutInterface,
203+
OperationPass<gpu::GPUModuleOp>> {
204+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPULayoutInterface)
205+
206+
StringRef getArgument() const final { return "test-xegpu-layout-interface"; }
207+
208+
StringRef getDescription() const final {
209+
return "Test the implementation of XeGPU Layout interfaces";
210+
}
211+
212+
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
213+
registry.insert<arith::ArithDialect>();
214+
registry.insert<memref::MemRefDialect>();
215+
registry.insert<xegpu::XeGPUDialect>();
216+
registry.insert<vector::VectorDialect>();
217+
registry.insert<index::IndexDialect>();
218+
}
219+
220+
TestXeGPULayoutInterface() = default;
221+
TestXeGPULayoutInterface(const TestXeGPULayoutInterface &pass)
222+
: PassWrapper(pass) {}
223+
224+
void runOnOperation() override {
225+
MLIRContext *ctx = &getContext();
226+
227+
TypeConverter typeConverter;
228+
auto materializeCast = [&](mlir::OpBuilder &builder, mlir::Type type,
229+
mlir::ValueRange inputs,
230+
mlir::Location loc) -> mlir::Value {
231+
return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
232+
.getResult(0);
233+
};
234+
typeConverter.addSourceMaterialization(materializeCast);
235+
typeConverter.addTargetMaterialization(materializeCast);
236+
237+
RewritePatternSet patterns(ctx);
238+
patterns.add<TestStepOpPattern>(typeConverter, ctx);
239+
240+
ConversionTarget target(*ctx);
241+
auto isLegal = [&](xegpu::SliceAttr layout) -> bool {
242+
return !layout || !layout.isWgLayout();
243+
};
244+
245+
target.addDynamicallyLegalOp<vector::StepOp>(
246+
[&](vector::StepOp op) -> bool {
247+
auto layoutName = xegpu::getLayoutName(op->getResult(0));
248+
auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
249+
return isLegal(sliceAttr);
250+
});
251+
252+
target.markUnknownOpDynamicallyLegal([](Operation *op) { return true; });
253+
254+
(void)applyPartialConversion(getOperation(), target, std::move(patterns));
255+
}
256+
};
257+
152258
} // namespace
153259

154260
namespace mlir {
155261
namespace test {
156262
void registerTestXeGPULowerings() {
157263
PassRegistration<TestXeGPUUnrollingPatterns>();
264+
PassRegistration<TestXeGPULayoutInterface>();
158265
}
159266
} // namespace test
160267
} // namespace mlir

0 commit comments

Comments
 (0)