Commit e240e47

Temp save.
1 parent 687e831 commit e240e47

5 files changed: +150 additions, −41 deletions

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 94 additions & 34 deletions
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
 
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -426,18 +427,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
   }
 };
 
-template <
-    typename OpType,
-    typename = std::enable_if_t<llvm::is_one_of<
-        OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
-        xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
-int64_t getElemByteSize(OpType op) {
-  // Get the element byte size from the tensor descriptor.
-  auto elemBitWidth =
-      op.getTensorDesc().getType().getElementType().getIntOrFloatBitWidth();
-  return elemBitWidth / 8;
-}
-
 // Add a builder that creates
 // offset * elemByteSize + baseAddr
 auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
@@ -456,23 +445,23 @@ class CreateDescToXeVMPattern
   LogicalResult
   matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto eTy = op.getTensorDescType().getElementType();
+    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
+      return rewriter.notifyMatchFailure(op,
+          "Expected element type bit width to be multiple of 8.");
+    }
     auto loc = op.getLoc();
+    // offsets are provided as scalar i64 by type converter.
     auto offsets = adaptor.getOffsets();
-    // Source type can be a 1D memref or ui64
-    // Using "op" instead of "adaptor" since we want to access memref type
-    // instead of LLVM struct type.
-    auto memrefTy = dyn_cast<MemRefType>(op.getSource().getType());
-    Value subGroupAddr;
-    if (memrefTy) {
-      subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create(
-          rewriter, loc, op.getSource());
-      subGroupAddr = arith::IndexCastUIOp::create(
-          rewriter, loc, rewriter.getI64Type(), subGroupAddr);
-    } else {
-      subGroupAddr = adaptor.getSource();
-    }
+    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
+    // But type converter will convert them to integer types.
+    Value addr = adaptor.getSource();
+    // ui32 or i32 are passed as i32 so they need to be casted to i64.
+    if (addr.getType() != rewriter.getI64Type())
+      addr = arith::IndexCastUIOp::create(
+          rewriter, loc, rewriter.getI64Type(), addr);
     auto laneAddr =
-        addOffset(rewriter, loc, subGroupAddr, offsets, getElemByteSize(op));
+        addOffset(rewriter, loc, addr, offsets, getElemByteSize(op));
     rewriter.replaceOp(op, laneAddr);
     return success();
   }
@@ -485,11 +474,18 @@ class UpdateOffsetToXeVMPattern
   matchAndRewrite(xegpu::UpdateOffsetOp op,
                   xegpu::UpdateOffsetOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto eTy = op.getTensorDescType().getElementType();
+    if (eTy.getIntOrFloatBitWidth() % 8 != 0) {
+      return rewriter.notifyMatchFailure(op,
+          "Expected element type bit width to be multiple of 8.");
+    }
     auto loc = op.getLoc();
-    Value newOffsetForLane =
+    // scatter descriptor is provided as scalar i64 by type converter.
+    // offsets are provided as scalar i64 by type converter.
+    Value newOffset =
         addOffset(rewriter, loc, adaptor.getTensorDesc(), adaptor.getOffsets(),
                   getElemByteSize(op));
-    rewriter.replaceOp(op, newOffsetForLane);
+    rewriter.replaceOp(op, newOffset);
     return success();
   }
 };
@@ -505,19 +501,38 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
     auto loc = op.getLoc();
     auto ctxt = rewriter.getContext();
     auto tdescTy = op.getTensorDescType();
-    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
-        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    if (tdescTy)
+      ptrTypeLLVM = LLVM::LLVMPointerType::get(
+          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64;
     if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
       basePtrI64 = adaptor.getSource();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
     } else {
       basePtrI64 = adaptor.getDest();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
     }
+    if (basePtrI64.getType() != rewriter.getI64Type()) {
+      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                                basePtrI64);
+    }
+    basePtrI64.dump();
     Value offsets = adaptor.getOffsets();
+    offsets.dump();
     Value mask = adaptor.getMask();
+    mask.dump();
     if (offsets) {
-      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
-      if (offsetsVecTy) {
+      if (dyn_cast<VectorType>(offsets.getType())){
        // Offset needs be scalar.
        return rewriter.notifyMatchFailure(op,
            "Expected offsets to be a scalar.");
@@ -526,8 +541,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
            addOffset(rewriter, loc, basePtrI64, offsets, getElemByteSize(op));
       }
     }
+    basePtrI64.dump();
     Value basePtrLLVM =
         LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+    basePtrLLVM.dump();
     VectorType srcOrDstVecTy = op.getValueType();
     VectorType srcOrDstFlatVecTy = VectorType::get(
         srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
@@ -597,6 +614,10 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
         ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
     Value basePtrI64 = adaptor.getSource();
     Value offsets = adaptor.getOffsets();
+    if (basePtrI64.getType() != rewriter.getI64Type()) {
+      basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                                basePtrI64);
+    }
     if (offsets) {
       VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
       if (offsetsVecTy) {
@@ -836,6 +857,26 @@ struct ConvertXeGPUToXeVMPass
       auto i32Type = IntegerType::get(&getContext(), 32);
       return VectorType::get(8, i32Type);
     });
+    typeConverter.addConversion([&](MemRefType type) -> Type {
+      // Convert MemRefType to i64 type.
+      return IntegerType::get(&getContext(), 64);
+    });
+
+    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
+                                        ValueRange inputs,
+                                        Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
+
+        Value addr = memref::ExtractAlignedPointerAsIndexOp::create(
+            builder, loc, input);
+        return arith::IndexCastUIOp::create(builder, loc, type,
+                                            addr).getResult();
+      }
+      return {};
+    };
 
     auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                       ValueRange inputs,
@@ -847,7 +888,22 @@ struct ConvertXeGPUToXeVMPass
         Value cast =
             index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                 .getResult();
-        return arith::IndexCastOp::create(builder, loc, type, cast).getResult();
+        return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
+      }
+      return {};
+    };
+
+    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
+                                      ValueRange inputs,
+                                      Location loc) -> Value {
+      if (inputs.size() != 1)
+        return {};
+      auto input = inputs.front();
+      if (input.getType() == builder.getIntegerType(32, false)) {
+        Value cast =
+            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
+                .getResult();
+        return arith::IndexCastUIOp::create(builder, loc, type, cast).getResult();
       }
       return {};
     };
@@ -864,15 +920,19 @@ struct ConvertXeGPUToXeVMPass
           Value cast =
               vector::ExtractOp::create(builder, loc, input, 0).getResult();
           if (vecTy.getElementType() == builder.getIndexType())
-            cast = arith::IndexCastOp::create(builder, loc, type, cast)
+            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                        .getResult();
           return cast;
         }
       }
       return {};
     };
+    typeConverter.addSourceMaterialization(memrefMaterializationCast);
     typeConverter.addSourceMaterialization(ui64MaterializationCast);
+    typeConverter.addSourceMaterialization(ui32MaterializationCast);
     typeConverter.addSourceMaterialization(vector1DMaterializationCast);
+    typeConverter.addTargetMaterialization(memrefMaterializationCast);
+    typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);
     typeConverter.addTargetMaterialization(vector1DMaterializationCast);
     ConversionTarget target(getContext());
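
Taken together, these changes route every scattered-descriptor source (1-D memref, ui64/ui32, i64/i32) through a single i64 address: the new MemRefType conversion plus the memref/ui32 materialization casts produce the base pointer, and the signed arith.index_cast uses are switched to the unsigned arith.index_castui. As a rough illustration only (not part of the commit; value names are made up), a memref-backed descriptor is expected to lower along these lines, mirroring the CHECK lines of update_offset.mlir below:

    // %src is the kernel's memref<128xf32> argument; %offset is a vector<1xindex>.
    %intptr = memref.extract_aligned_pointer_as_index %src : memref<128xf32> -> index
    %base   = arith.index_castui %intptr : index to i64
    %off    = vector.extract %offset[0] : index from vector<1xindex>
    %off64  = arith.index_castui %off : index to i64
    %c4     = arith.constant 4 : i64          // element byte size of f32
    %scaled = arith.muli %off64, %c4 : i64
    %addr   = arith.addi %base, %scaled : i64 // addOffset: base + offset * elemByteSize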

mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@ gpu.module @create_nd_tdesc {
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index
   gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
       %stride1: index, %stride2: index) kernel {
-    // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
-    // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+    // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<8xi32>
     // CHECK: %[[C0_I32:.*]] = arith.constant 0 : i32
     // CHECK: %[[C0_I32_0:.*]] = arith.constant 0 : i32

mlir/test/Conversion/XeGPUToXeVM/loadstoreprefetch.mlir

Lines changed: 2 additions & 2 deletions
@@ -5,10 +5,10 @@ gpu.module @test {
   // CHECK-SAME: %[[ARG0:.*]]: ui64
   gpu.func @load_gather_ui64_src_constant_offset(%src: ui64) {
     // CHECK: %[[VAR0:.*]] = index.castu %[[ARG0]] : ui64 to index
-    // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
     // CHECK: %[[VAR2:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-    // CHECK: %[[VAR3:.*]] = arith.index_cast %[[VAR2]] : index to i64
+    // CHECK: %[[VAR3:.*]] = arith.index_castui %[[VAR2]] : index to i64
     %0 = arith.constant dense<0> : vector<1xindex>
     // CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<1xi1>
     // CHECK: %[[VAR4:.*]] = vector.extract %[[CST_0]][0] : i1 from vector<1xi1>

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+// RUN: mlir-opt -convert-xegpu-to-xevm %s | FileCheck %s
+
+gpu.module @materializecast {
+  // CHECK-LABEL: gpu.func @materialize_memref
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+  gpu.func @materialize_memref(%src: memref<128xf32>) kernel {
+    // CHECK: XXX
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_ui64
+  // CHECK-SAME: %[[ARG0:.*]]: ui64
+  gpu.func @materialize_ui64(%src: ui64) kernel {
+    // CHECK: XXX
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %src_tdesc = xegpu.create_tdesc %src, %offset : ui64, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_ui32
+  // CHECK-SAME: %[[ARG0:.*]]: ui32
+  gpu.func @materialize_ui32(%src: ui32) kernel {
+    %offset = arith.constant dense<0> : vector<1xindex>
+    //%src_tdesc = xegpu.create_tdesc %src, %offset : ui32, vector<1xindex>
+    //    -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_single_index_vector
+  // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
+  gpu.func @materialize_single_index_vector(%src: memref<128xf32>) kernel {
+    // CHECK: XXX
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %src_tdesc = xegpu.create_tdesc %src, %offset : memref<128xf32>, vector<1xindex>
+        -> !xegpu.tensor_desc<1xf32, #xegpu.scatter_tdesc_attr<>>
+    gpu.return
+  }
+  // CHECK-LABEL: gpu.func @materialize_single_elem_vector
+  // CHECK-SAME: %[[ARG0:.*]]: vector<1xi1>
+  gpu.func @materialize_single_elem_vector(%src: memref<128xf32>) kernel {
+    // CHECK: XXX
+    %mask = arith.constant dense<1>: vector<1xi1>
+    %offset = arith.constant dense<0> : vector<1xindex>
+    %0 = xegpu.load %src[%offset], %mask <{chunk_size=8, l1_hint = #xegpu.cache_hint<cached>, l2_hint = #xegpu.cache_hint<uncached>}>
+        : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<1x8xf32>
+    gpu.return
+  }
+}
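
The "// CHECK: XXX" lines in this new test are placeholders left by the temp save. For the ui64 case, the ui64MaterializationCast added above goes through the index dialect, so the filled-in checks would presumably match IR of roughly this shape (illustrative sketch only, mirroring the existing loadstoreprefetch.mlir checks; value names are made up):

    %idx  = index.castu %src : ui64 to index
    %base = arith.index_castui %idx : index to i64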

mlir/test/Conversion/XeGPUToXeVM/update_offset.mlir

Lines changed: 3 additions & 3 deletions
@@ -4,12 +4,12 @@ gpu.module @update_offset {
   // CHECK-LABEL: gpu.func @update_offset
   // CHECK-SAME: %[[ARG0:.*]]: memref<128xf32>
   gpu.func @update_offset(%src: memref<128xf32>) kernel {
+    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
+    // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
     // CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
     %offset = arith.constant dense<0> : vector<1xindex>
     // CHECK: %[[VAR0:.*]] = vector.extract %[[CST]][0] : index from vector<1xindex>
-    // CHECK: %[[VAR1:.*]] = arith.index_cast %[[VAR0]] : index to i64
-    // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<128xf32> -> index
-    // CHECK: %[[VAR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
+    // CHECK: %[[VAR1:.*]] = arith.index_castui %[[VAR0]] : index to i64
     // CHECK: %[[C4_I64:.*]] = arith.constant 4 : i64
     // CHECK: %[[VAR3:.*]] = arith.muli %[[VAR1]], %[[C4_I64]] : i64
     // CHECK: %[[VAR4:.*]] = arith.addi %[[VAR2]], %[[VAR3]] : i64
