@@ -215,17 +215,16 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
215215 // LayoutAttr
216216
217217 auto computeTileShapeAndCount = [&](ArrayRef<int64_t > shape,
218- DenseI32ArrayAttr sgDataAttr,
219- DenseI32ArrayAttr sgLayoutAttr) {
218+ DenseI32ArrayAttr sgDataAttr,
219+ DenseI32ArrayAttr sgLayoutAttr) {
220220 SmallVector<int64_t > tileShape;
221221 auto sgLayout = llvm::to_vector_of<int64_t >(sgLayoutAttr.asArrayRef ());
222222 if (sgDataAttr)
223223 tileShape = llvm::to_vector_of<int64_t >(sgDataAttr.asArrayRef ());
224224 else
225225 tileShape = computeShapeRatio (shape, sgLayout).value_or (tileShape);
226226 assert (tileShape.size () && " failed to compute tileShape" );
227- SmallVector<int64_t > distUnit =
228- computeElementwiseMul (sgLayout, tileShape);
227+ SmallVector<int64_t > distUnit = computeElementwiseMul (sgLayout, tileShape);
229228 int count = computeProduct (shape) / computeProduct (distUnit);
230229 return std::make_pair (tileShape, count);
231230 };
@@ -249,8 +248,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
249248 if (layout.isWgLayout ()) {
250249 // for WgToSg, the subShape is either from sgData or computed as
251250 // shape/sgLayout
252- std::tie (subShape, count) = computeTileShapeAndCount (
253- shape, layout.getSgData (), layout.getSgLayout ());
251+ std::tie (subShape, count) = computeTileShapeAndCount (shape, layout.getSgData (), layout.getSgLayout ());
254252 } else if (DenseI32ArrayAttr instData = layout.getInstData ()) {
255253 // for unrolling, the subShape is determined by inst_data
256254 subShape = llvm::to_vector_of<int64_t >(instData.asArrayRef ());
@@ -280,8 +278,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
280278 if (layout.isWgLayout ()) {
281279 // for WgToSg, the subShape is either from sgData or computed as
282280 // shape/sgLayout
283- std::tie (subShape, count) = computeTileShapeAndCount (
284- shape, layout.getSgData (), layout.getSgLayout ());
281+ std::tie (subShape, count) = computeTileShapeAndCount (shape, layout.getSgData (), layout.getSgLayout ());
285282 layout = layout.dropSgLayoutAndData ();
286283 } else if (DenseI32ArrayAttr instData = layout.getInstData ()) {
287284 // for unrolling, the subShape is determined by inst_data
@@ -298,7 +295,11 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
298295 });
299296
300297 converter.addSourceMaterialization (materializeCast);
301- converter.addTargetMaterialization (materializeCast);
298+ converter.addTargetMaterialization ([&](OpBuilder &builder, TypeRange type,
299+ ValueRange inputs, Location loc) {
300+ return builder.create <UnrealizedConversionCastOp>(loc, type, inputs)
301+ .getResults ();
302+ });
302303
303304 mlir::ConversionTarget target (*context);
304305 target.addLegalOp <UnrealizedConversionCastOp>();
0 commit comments