|
20 | 20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
21 | 21 | #include "mlir/Dialect/SCF/IR/SCF.h" |
22 | 22 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
| 23 | +#include "mlir/Dialect/Vector/IR/VectorOps.h" |
23 | 24 | #include "mlir/Dialect/XeGPU/IR/XeGPU.h" |
24 | 25 | #include "mlir/Pass/Pass.h" |
25 | 26 | #include "mlir/Support/LLVM.h" |
@@ -390,7 +391,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { |
390 | 391 | // Load result or Store valye Type can be vector or scalar. |
391 | 392 | Type valOrResTy; |
392 | 393 | if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) |
393 | | - valOrResTy = op.getResult().getType(); |
| 394 | + valOrResTy = |
| 395 | + this->getTypeConverter()->convertType(op.getResult().getType()); |
394 | 396 | else |
395 | 397 | valOrResTy = adaptor.getValue().getType(); |
396 | 398 | VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy); |
@@ -878,10 +880,30 @@ struct ConvertXeGPUToXeVMPass |
878 | 880 | } |
879 | 881 | return {}; |
880 | 882 | }; |
881 | | - typeConverter.addSourceMaterialization(memrefMaterializationCast); |
882 | | - typeConverter.addSourceMaterialization(ui64MaterializationCast); |
883 | | - typeConverter.addSourceMaterialization(ui32MaterializationCast); |
884 | | - typeConverter.addSourceMaterialization(vectorMaterializationCast); |
| 883 | + |
| 884 | + // If result type of original op is single element vector and lowered type |
| 885 | + // is scalar. This materialization cast creates a single element vector by |
| 886 | + // broadcasting the scalar value. |
| 887 | + auto singleElementVectorMaterializationCast = |
| 888 | + [](OpBuilder &builder, Type type, ValueRange inputs, |
| 889 | + Location loc) -> Value { |
| 890 | + if (inputs.size() != 1) |
| 891 | + return {}; |
| 892 | + auto input = inputs.front(); |
| 893 | + if (input.getType().isIntOrIndexOrFloat()) { |
| 894 | + // If the input is a scalar, and the target type is a vector of single |
| 895 | + // element, create a single element vector by broadcasting. |
| 896 | + if (auto vecTy = dyn_cast<VectorType>(type)) { |
| 897 | + if (vecTy.getNumElements() == 1) { |
| 898 | + return vector::BroadcastOp::create(builder, loc, vecTy, input) |
| 899 | + .getResult(); |
| 900 | + } |
| 901 | + } |
| 902 | + } |
| 903 | + return {}; |
| 904 | + }; |
| 905 | + typeConverter.addSourceMaterialization( |
| 906 | + singleElementVectorMaterializationCast); |
885 | 907 | typeConverter.addTargetMaterialization(memrefMaterializationCast); |
886 | 908 | typeConverter.addTargetMaterialization(ui32MaterializationCast); |
887 | 909 | typeConverter.addTargetMaterialization(ui64MaterializationCast); |
|
0 commit comments