@@ -225,7 +225,7 @@ Value xegpu::createVectorWithShapeFromValues(OpBuilder &builder, Location loc,
   return result;
 }
 
-void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
+void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op, TypeConverter converter) {
   MLIRContext *context = op->getContext();
 
   auto materializeCast = [&](OpBuilder &builder, Type type, ValueRange inputs,
@@ -307,109 +307,11 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
 
   { // perform the conversion from RankedTensorType to VectorType based on the
     // LayoutAttr
-    auto computeTileShapeAndCount = [&](ArrayRef<int64_t> shape,
-                                        DenseI32ArrayAttr sgDataAttr,
-                                        DenseI32ArrayAttr sgLayoutAttr) {
-      SmallVector<int64_t> tileShape;
-      auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
-      if (sgDataAttr)
-        tileShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
-      else
-        tileShape = computeShapeRatio(shape, sgLayout).value_or(tileShape);
-      assert(tileShape.size() && "failed to compute tileShape");
-      SmallVector<int64_t> distUnit =
-          computeElementwiseMul(sgLayout, tileShape);
-      int count = computeProduct(shape) / computeProduct(distUnit);
-      return std::make_pair(tileShape, count);
-    };
-
-    TypeConverter converter;
-    converter.addConversion([&](Type type) -> Type { return type; });
-    converter.addConversion(
-        [&](RankedTensorType type,
-            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-          ArrayRef<int64_t> shape = type.getShape();
-          auto encoding = type.getEncoding();
-          Type elemTy = type.getElementType();
-
-          // init count and subShape to the default value. If the LayoutAttr
-          // is not present, it will return a VectorType with original shape.
-          int count = 1;
-          SmallVector<int64_t> subShape(shape);
-
-          if (auto layout =
-                  llvm::dyn_cast_if_present<xegpu::LayoutAttr>(encoding)) {
-            if (layout.isWgLayout()) {
-              // for WgToSg, the subShape is either from sgData or computed as
-              // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(
-                  shape, layout.getSgData(), layout.getSgLayout());
-            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
-              // for unrolling, the subShape is determined by inst_data
-              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-              count = computeProduct(shape) / computeProduct(subShape);
-            }
-          }
-          auto newTy = VectorType::get(subShape, elemTy);
-          result.append(count, newTy);
-          return success();
-        });
-
-    converter.addConversion(
-        [&](xegpu::TensorDescType type,
-            SmallVectorImpl<Type> &result) -> std::optional<LogicalResult> {
-          MLIRContext *ctx = type.getContext();
-          Type elemTy = type.getElementType();
-          Attribute encoding = type.getEncoding();
-          ArrayRef<int64_t> shape = type.getShape();
-
-          // init count and newTy to the default value. If the layout attribute
-          // is not present, it will return the original type.
-          int count = 1;
-          Type newTy = type;
-
-          if (xegpu::LayoutAttr layout = type.getLayoutAttr()) {
-            SmallVector<int64_t> subShape(shape);
-            if (layout.isWgLayout()) {
-              // for WgToSg, the subShape is either from sgData or computed as
-              // shape/sgLayout
-              std::tie(subShape, count) = computeTileShapeAndCount(
-                  shape, layout.getSgData(), layout.getSgLayout());
-              layout = layout.dropSgLayoutAndData();
-            } else if (DenseI32ArrayAttr instData = layout.getInstData()) {
-              // for unrolling, the subShape is determined by inst_data
-              subShape = llvm::to_vector_of<int64_t>(instData.asArrayRef());
-              count = computeProduct(shape) / computeProduct(subShape);
-              layout = layout.dropInstData();
-            }
-
-            newTy = xegpu::TensorDescType::get(ctx, subShape, elemTy, encoding,
-                                               layout);
-          }
-
-          result.append(count, newTy);
-          return success();
-        });
-
-    converter.addSourceMaterialization(materializeCast);
-    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
-                                           ValueRange inputs, Location loc) {
-      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
-          .getResults();
-    });
-
-    mlir::ConversionTarget target(*context);
-    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
-        [&](UnrealizedConversionCastOp op) {
-          auto isTensorTy = [&](Type type) {
-            return isa<RankedTensorType>(type);
-          };
-          if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
-              llvm::any_of(op->getResultTypes(), isTensorTy))
-            return false;
-          return true;
-        });
 
+    // Handle the UnrealizedConversionCastOp introduced by the first step.
+    // For vector->RankedTensorType, it will simply forward the inputs.
+    // For RankedTensorType->vector, it will update the inputs with the
+    // values from the adaptor.
     class UnrealizedConversionCastOpPattern
         : public OpConversionPattern<mlir::UnrealizedConversionCastOp> {
       using OpConversionPattern<
@@ -444,6 +346,24 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(Operation *op) {
       }
     };
 
+    converter.addSourceMaterialization(materializeCast);
+    converter.addTargetMaterialization([&](OpBuilder &builder, TypeRange type,
+                                           ValueRange inputs, Location loc) {
+      return builder.create<UnrealizedConversionCastOp>(loc, type, inputs)
+          .getResults();
+    });
+
+    mlir::ConversionTarget target(*context);
+    target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+        [&](UnrealizedConversionCastOp op) {
+          auto isTensorTy = [&](Type type) {
+            return isa<RankedTensorType>(type);
+          };
+          if (llvm::any_of(op->getOperandTypes(), isTensorTy) ||
+              llvm::any_of(op->getResultTypes(), isTensorTy))
+            return false;
+          return true;
+        });
     mlir::RewritePatternSet patterns(context);
     patterns.insert<UnrealizedConversionCastOpPattern>(context);
     scf::populateSCFStructuralTypeConversionsAndLegality(converter, patterns,
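
Note on the new signature: the TypeConverter is now supplied by the caller instead of being built inside the utility. The following is only a rough caller-side sketch, not part of this commit; it is simplified to a 1:1 type mapping and omits the sgData/instData 1:N splitting that the removed block performed.

// Hypothetical caller-side setup (assumed, for illustration only).
TypeConverter converter;
// Leave unrelated types untouched.
converter.addConversion([](Type type) -> Type { return type; });
// Map a RankedTensorType to a VectorType of the same shape; any layout-aware
// tiling would be added here by the caller as needed.
converter.addConversion([](RankedTensorType type) -> Type {
  return VectorType::get(type.getShape(), type.getElementType());
});
// Run the structural SCF conversion with the caller-provided converter.
xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);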