@@ -113,94 +113,6 @@ IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
   return rewriter.getIntegerAttr(dstType, intVal);
 }
 
-struct RawAllocator {
-  RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {}
-
-  std::variant<Value, int64_t> computeTotalBytes(MemRefType srcType,
-                                                 Value srcMemref) {
-    // Element size in bytes.
-    int64_t elemBitWidth = srcType.getElementTypeBitWidth();
-    int64_t elemByteWidth = (elemBitWidth + 7) / 8;
-
-    if (srcType.hasStaticShape()) {
-      // Static shape: compute total bytes statically.
-      int64_t numElements = 1;
-      for (int64_t dim : srcType.getShape()) {
-        numElements *= dim;
-      }
-      return numElements * elemByteWidth;
-    }
-
-    auto sizes = getSizes(srcType, srcMemref);
-    // Compute number of elements dynamically.
-    Value numElements = sizes.front();
-    for (auto size : llvm::drop_begin(sizes))
-      numElements = builder.create<arith::MulIOp>(loc, numElements, size);
-    Value elemSize = builder.create<arith::ConstantIndexOp>(loc, elemByteWidth);
-
-    return builder.create<arith::MulIOp>(loc, numElements, elemSize);
-  }
-
-  SmallVector<Value> getSizes(MemRefType type, Value memref) {
-    SmallVector<Value> sizes;
-    for (unsigned i = 0; i < type.getRank(); ++i) {
-      if (type.isDynamicDim(i)) {
-        sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
-      } else {
-        sizes.push_back(
-            builder.create<arith::ConstantIndexOp>(loc, type.getShape()[i]));
-      }
-    }
-    return sizes;
-  }
-
-  SmallVector<Value> getDynamicSizes(MemRefType type, Value memref) {
-    SmallVector<Value> sizes;
-    for (unsigned i = 0; i < type.getRank(); ++i) {
-      if (type.isDynamicDim(i)) {
-        sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
-      }
-    }
-    return sizes;
-  }
-
-  SmallVector<Value> getIdentityStrides(MemRefType type) {
-    SmallVector<Value> strides;
-    int64_t runningStride = 1;
-    for (int64_t dim : llvm::reverse(type.getShape())) {
-      strides.push_back(
-          builder.create<arith::ConstantIndexOp>(loc, runningStride));
-      if (dim != ShapedType::kDynamic)
-        runningStride *= dim;
-      else
-        runningStride = -1; // not handling dynamic strides.
-    }
-    std::reverse(strides.begin(), strides.end());
-    return strides;
-  }
-
-private:
-  OpBuilder &builder;
-  Location loc;
-};
-
-// Replace uses according to predicates automatically.
-template <typename OpTy>
-void replaceUsesWithPredicate(
-    OpTy originalValue,
-    ArrayRef<std::pair<std::function<bool(OpOperand &)>, Value>> replacements,
-    ConversionPatternRewriter &rewriter) {
-
-  for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) {
-    for (const auto &[predicate, newValue] : replacements) {
-      if (predicate(use)) {
-        use.set(newValue);
-        break;
-      }
-    }
-  }
-}
-
 //===----------------------------------------------------------------------===//
 // Conversion patterns
 //===----------------------------------------------------------------------===//
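Note: the removed `replaceUsesWithPredicate` helper was a first-match-wins dispatcher over a value's uses. A minimal call-site sketch, assuming an `alloc` op plus hypothetical `convertedView`/`originalView` values and a `rewriter` already in scope (all names illustrative, not from this commit):

```cpp
// The first predicate that accepts a use wins, so the catch-all goes last.
// Routes launch-func uses to the converted view, everything else to the
// original view.
replaceUsesWithPredicate(
    alloc,
    {{[](OpOperand &use) { return isa<gpu::LaunchFuncOp>(use.getOwner()); },
      convertedView},
     {[](OpOperand &use) { return true; }, originalView}},
    rewriter);
```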
@@ -355,127 +267,6 @@ struct ConvertGPULaunchFuncOp : OpConversionPattern<gpu::LaunchFuncOp> {
   }
 };
 
-//===----------------------------------------------------------------------===//
-// AllocOp conversion pattern
-//===----------------------------------------------------------------------===//
-template <typename AllocOp>
-struct ConvertAllocOp : OpConversionPattern<AllocOp> {
-  ConvertAllocOp(MLIRContext *ctx, TypeConverter &typeConverter)
-      : OpConversionPattern<AllocOp>(ctx), typeConverter(typeConverter) {}
-
-  LogicalResult
-  matchAndRewrite(AllocOp op, typename AllocOp::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    MemRefType srcType = llvm::cast<MemRefType>(op.getType());
-    // Only supports memref types with identity layout. Since this mechanism
-    // requires the usage of memref.ViewOp, which requires the layout to be
-    // identity.
-    if (!srcType.getLayout().isIdentity())
-      op.emitError("only memrefs with identity layout is supported");
-
-    auto dstType =
-        dyn_cast_or_null<MemRefType>(typeConverter.convertType(srcType));
-    if (!dstType || dstType == srcType)
-      return failure(); // No need to rewrite.
-
-    // Helper class to allocate raw memory.
-    RawAllocator allocator(rewriter, loc);
-
-    // 1. Compute total allocation size.
-    auto totalBytes = allocator.computeTotalBytes(srcType, op.getMemref());
-
-    // 2. Create raw i8 buffer.
-    MemRefType rawType;
-    if (std::holds_alternative<int64_t>(totalBytes)) {
-      // Static size.
-      SmallVector<int64_t> staticI8Shape;
-      staticI8Shape.push_back(std::get<int64_t>(totalBytes));
-      rawType = MemRefType::get(staticI8Shape, rewriter.getI8Type(), {},
-                                srcType.getMemorySpaceAsInt());
-    } else {
-      // Dynamic size.
-      rawType = MemRefType::get({ShapedType::kDynamic}, rewriter.getI8Type(),
-                                {}, srcType.getMemorySpaceAsInt());
-    }
-    Value rawAlloc;
-
-    if constexpr (std::is_same_v<AllocOp, gpu::AllocOp>) {
-      rawAlloc =
-          rewriter
-              .create<gpu::AllocOp>(
-                  loc, rawType,
-                  op.getAsyncToken() ? op.getAsyncToken().getType() : nullptr,
-                  adaptor.getAsyncDependencies(),
-                  std::holds_alternative<Value>(totalBytes)
-                      ? ValueRange{std::get<Value>(totalBytes)}
-                      : ValueRange{},
-                  adaptor.getSymbolOperands(), op.getHostShared())
-              .getResult(0);
-    } else {
-      rawAlloc = rewriter.create<memref::AllocOp>(
-          loc, rawType,
-          std::holds_alternative<Value>(totalBytes)
-              ? ValueRange{std::get<Value>(totalBytes)}
-              : ValueRange{},
-          op.getSymbolOperands());
-    }
-
-    // 3. Create view for original type.
-    SmallVector<Value> dynamicSizes =
-        allocator.getDynamicSizes(srcType, op.getMemref());
-    // Since we are using memref::ViewOp, only identity strides are supported.
-    SmallVector<Value> dynamicStrides = allocator.getIdentityStrides(srcType);
-    Value zeroOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value originalView = rewriter.create<memref::ViewOp>(
-        loc, srcType, rawAlloc, zeroOffset, dynamicSizes);
-
-    // 4. Create view for converted type.
-    Value convertedView = rewriter.create<memref::ViewOp>(
-        loc, dstType, rawAlloc, zeroOffset, dynamicSizes);
-
-    // 5. Replace uses:
-    //    gpu::LaunchFuncOp uses -> Replace the original AllocOp use in
-    //                              gpu::LaunchFuncOp with the view of the
-    //                              converted type.
-    //
-    //    DeallocOp uses -> Replace the original AllocOp use in dealloc with
-    //                      the new AllocOp.
-    //
-    //    Other uses -> Replace the original AllocOp use with the view of the
-    //                  original type.
-
-    SmallVector<OpOperand *> launchFuncUses;
-    SmallVector<OpOperand *> deallocUses;
-    SmallVector<OpOperand *> otherUses;
-
-    for (OpOperand &use : op->getUses()) {
-      if (isa<gpu::LaunchFuncOp>(use.getOwner())) {
-        launchFuncUses.push_back(&use);
-      } else if (isa<memref::DeallocOp>(use.getOwner()) ||
-                 isa<gpu::DeallocOp>(use.getOwner())) {
-        deallocUses.push_back(&use);
-      } else {
-        otherUses.push_back(&use);
-      }
-    }
-
-    for (OpOperand *use : launchFuncUses)
-      use->set(convertedView);
-    for (OpOperand *use : deallocUses)
-      use->set(rawAlloc);
-    for (OpOperand *use : otherUses)
-      use->set(originalView);
-
-    // Erase the original AllocOp.
-    rewriter.eraseOp(op);
-    return success();
-  }
-
-private:
-  TypeConverter &typeConverter;
-};
-
 //===----------------------------------------------------------------------===//
 // ArithConstantOp conversion pattern
 //===----------------------------------------------------------------------===//
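For the record, the size computation the deleted `ConvertAllocOp` relied on is plain round-up-and-multiply arithmetic. A standalone restatement of the static-shape branch (ordinary C++, not MLIR; the function name is illustrative):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Restates the deleted static-shape byte math: round the element bit width
// up to whole bytes, then multiply by the element count.
static int64_t totalBytes(int64_t elemBitWidth,
                          const std::vector<int64_t> &shape) {
  int64_t elemByteWidth = (elemBitWidth + 7) / 8;
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  return numElements * elemByteWidth;
}

int main() {
  // E.g. a 4x8 buffer of 16-bit elements: 32 elements * 2 bytes = 64 raw
  // i8 bytes, which both zero-offset memref.view results then reinterpret.
  std::printf("%lld\n", static_cast<long long>(totalBytes(16, {4, 8})));
}
```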
@@ -688,12 +479,10 @@ void mlir::populateImitateUnsupportedTypesTypeConverter(
       ValueRange inputs, Location loc) -> Value {
     assert(inputs.size() == 1 && "Expected single input");
     Type inputType = inputs[0].getType();
-    if (isa<MemRefType>(resultType) && isa<MemRefType>(inputType)) {
-      return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
-          .getResult(0);
-    }
-    if ((resultType.isIntOrIndexOrFloat() || isa<VectorType>(resultType)) &&
-        (inputType.isIntOrIndexOrFloat() || isa<VectorType>(inputType))) {
+    if ((resultType.isIntOrIndexOrFloat() || isa<VectorType>(resultType) ||
+         isa<MemRefType>(resultType)) &&
+        (inputType.isIntOrIndexOrFloat() || isa<VectorType>(inputType) ||
+         isa<MemRefType>(inputType))) {
       return builder.create<arith::BitcastOp>(loc, resultType, inputs[0])
           .getResult();
     }
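This lambda is the kind of callback handed to `TypeConverter::addSourceMaterialization`. A minimal sketch of the registration shape, assuming the simplest case where a single bitcast suffices (the converter variable and the reduced body are illustrative):

```cpp
// Sketch: register a source materialization that bridges a converted value
// back to its original type with one bitcast, mirroring the callback edited
// above. Returning a null Value instead would signal a materialization
// failure to the conversion framework.
TypeConverter typeConverter;
typeConverter.addSourceMaterialization(
    [](OpBuilder &builder, Type resultType, ValueRange inputs,
       Location loc) -> Value {
      assert(inputs.size() == 1 && "Expected single input");
      return builder.create<arith::BitcastOp>(loc, resultType, inputs[0])
          .getResult();
    });
```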
@@ -724,8 +513,6 @@ void mlir::populateImitateUnsupportedTypesConversionPatterns(
   patterns.add<ConvertCallOp>(ctx, typeConverter, convertedFuncTypes);
   patterns.add<ConvertArithConstantOp>(ctx, typeConverter, srcTypes, tgtTypes);
   patterns.add<ConvertGPULaunchFuncOp>(ctx);
-  patterns.add<ConvertAllocOp<gpu::AllocOp>>(ctx, typeConverter);
-  patterns.add<ConvertAllocOp<memref::AllocOp>>(ctx, typeConverter);
 }
 
 //===----------------------------------------------------------------------===//
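These populate/configure entry points are normally wired together in the pass's `runOnOperation`, as the later hunks of this commit show. A sketch of that wiring, with the argument lists abridged in comments because the diff only exposes part of each signature:

```cpp
// Sketch of the usual dialect-conversion wiring for these helpers; exact
// parameter lists are abridged (/*...*/) since they are not fully visible
// in this diff.
void GpuImitateUnsupportedTypesPass::runOnOperation() {
  MLIRContext *ctx = &getContext();
  TypeConverter typeConverter;
  populateImitateUnsupportedTypesTypeConverter(typeConverter /*, types */);
  RewritePatternSet patterns(ctx);
  populateImitateUnsupportedTypesConversionPatterns(patterns /*, ... */);
  ConversionTarget target(*ctx);
  configureImitateUnsupportedTypesLegality(target, typeConverter);
  // Partial conversion: ops the target deems legal are left untouched.
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    return signalPassFailure();
}
```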
@@ -744,8 +531,11 @@ void mlir::configureImitateUnsupportedTypesLegality(
     return true;
   });
 
-  target.addDynamicallyLegalDialect<gpu::GPUDialect>(
-      [&](Operation *op) { return typeConverter.isLegal(op); });
+  target.addDynamicallyLegalDialect<gpu::GPUDialect>([&](Operation *op) {
+    if (op->getParentOfType<gpu::GPUModuleOp>())
+      return typeConverter.isLegal(op);
+    return true;
+  });
 
   target.addDynamicallyLegalDialect<func::FuncDialect>([&](Operation *op) {
     if (op->getParentOfType<gpu::GPUModuleOp>())
@@ -755,7 +545,6 @@ void mlir::configureImitateUnsupportedTypesLegality(
755545 });
756546
757547 target.addLegalOp <gpu::GPUModuleOp>();
758- target.addLegalOp <UnrealizedConversionCastOp>();
759548 // Manually mark arithmetic-performing vector instructions.
760549 target.addLegalOp <vector::ContractionOp, vector::ReductionOp,
761550 vector::MultiDimReductionOp, vector::FMAOp,
@@ -767,6 +556,8 @@ void mlir::configureImitateUnsupportedTypesLegality(
   target.addDynamicallyLegalOp<gpu::GPUFuncOp>([&](gpu::GPUFuncOp op) {
     return typeConverter.isSignatureLegal(op.getFunctionType());
   });
+  target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
+      [&](gpu::LaunchFuncOp op) { return typeConverter.isLegal(op); });
   // Only convert functions and function calls in gpu.module
   target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
     if (op->getParentOfType<gpu::GPUModuleOp>())
@@ -779,22 +570,8 @@ void mlir::configureImitateUnsupportedTypesLegality(
     return true;
   });
 
-  // Only convert alloc ops in gpu.module or in host functions and has a use
-  // in LaunchFunc
-  target.addDynamicallyLegalOp<memref::AllocOp>([&](memref::AllocOp op) {
-    if (op->getParentOfType<gpu::GPUModuleOp>())
-      return typeConverter.isLegal(op.getType());
-    else {
-      for (auto user : op->getUsers()) {
-        if (isa<gpu::LaunchFuncOp>(user))
-          return typeConverter.isLegal(op.getType());
-      }
-    }
-    return true;
-  });
-
-  // Mark unknown ops that are inside gpu.module, and one of its's operand is a
-  // memref type as dynamically legal.
+  // Mark unknown ops inside gpu.module that have a memref operand as
+  // dynamically legal.
   target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
     // Check if the operation is inside a gpu.module.
     if (op->getParentOfType<gpu::GPUModuleOp>()) {
@@ -899,21 +676,6 @@ struct GpuImitateUnsupportedTypesPass
     // Apply the conversion.
     if (failed(applyPartialConversion(op, target, std::move(patterns))))
       return signalPassFailure();
-
-    // Post-conversion validation: check for any remaining
-    // unrealized_conversion_cast.
-    op->walk([&](UnrealizedConversionCastOp op) {
-      // Check if the cast is from a source type to a target type.
-      for (auto [sourceType, targetType] :
-           llvm::zip_equal(sourceTypes, targetTypes)) {
-        if (getElementTypeOrSelf(op.getOperand(0).getType()) == sourceType &&
-            getElementTypeOrSelf(op.getResult(0).getType()) == targetType) {
-          op->emitError("unresolved unrealized_conversion_cast left in IR "
-                        "after conversion");
-          return signalPassFailure();
-        }
-      }
-    });
   }
 };
 
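Passes that still produce such casts often keep a hygiene check like the one deleted above. A generic form of that walk, assuming the pass's parallel `sourceTypes`/`targetTypes` lists are in scope as before:

```cpp
// Generic form of the removed check: after conversion, flag any
// unrealized_conversion_cast that still bridges a source type to a
// target type.
op->walk([&](UnrealizedConversionCastOp castOp) {
  Type in = getElementTypeOrSelf(castOp.getOperand(0).getType());
  Type out = getElementTypeOrSelf(castOp.getResult(0).getType());
  for (auto [src, tgt] : llvm::zip_equal(sourceTypes, targetTypes))
    if (in == src && out == tgt)
      castOp->emitError("unresolved unrealized_conversion_cast "
                        "left in IR after conversion");
});
```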