Skip to content

Commit f13507e

Browse files
committed
Use arith.bitcast for memref element casting.
Remove the usage of the memref.view op and the restrictions that come with it. Makes the pass straightforward.
1 parent 0cb8c96 commit f13507e

File tree

2 files changed

+133
-306
lines changed

2 files changed

+133
-306
lines changed

mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp

Lines changed: 13 additions & 251 deletions
Original file line numberDiff line numberDiff line change
@@ -113,94 +113,6 @@ IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
113113
return rewriter.getIntegerAttr(dstType, intVal);
114114
}
115115

116-
struct RawAllocator {
117-
RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {}
118-
119-
std::variant<Value, int64_t> computeTotalBytes(MemRefType srcType,
120-
Value srcMemref) {
121-
// Element size in bytes.
122-
int64_t elemBitWidth = srcType.getElementTypeBitWidth();
123-
int64_t elemByteWidth = (elemBitWidth + 7) / 8;
124-
125-
if (srcType.hasStaticShape()) {
126-
// Static shape: compute total bytes statically.
127-
int64_t numElements = 1;
128-
for (int64_t dim : srcType.getShape()) {
129-
numElements *= dim;
130-
}
131-
return numElements * elemByteWidth;
132-
}
133-
134-
auto sizes = getSizes(srcType, srcMemref);
135-
// Compute number of elements dynamically.
136-
Value numElements = sizes.front();
137-
for (auto size : llvm::drop_begin(sizes))
138-
numElements = builder.create<arith::MulIOp>(loc, numElements, size);
139-
Value elemSize = builder.create<arith::ConstantIndexOp>(loc, elemByteWidth);
140-
141-
return builder.create<arith::MulIOp>(loc, numElements, elemSize);
142-
}
143-
144-
SmallVector<Value> getSizes(MemRefType type, Value memref) {
145-
SmallVector<Value> sizes;
146-
for (unsigned i = 0; i < type.getRank(); ++i) {
147-
if (type.isDynamicDim(i)) {
148-
sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
149-
} else {
150-
sizes.push_back(
151-
builder.create<arith::ConstantIndexOp>(loc, type.getShape()[i]));
152-
}
153-
}
154-
return sizes;
155-
}
156-
157-
SmallVector<Value> getDynamicSizes(MemRefType type, Value memref) {
158-
SmallVector<Value> sizes;
159-
for (unsigned i = 0; i < type.getRank(); ++i) {
160-
if (type.isDynamicDim(i)) {
161-
sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
162-
}
163-
}
164-
return sizes;
165-
}
166-
167-
SmallVector<Value> getIdentityStrides(MemRefType type) {
168-
SmallVector<Value> strides;
169-
int64_t runningStride = 1;
170-
for (int64_t dim : llvm::reverse(type.getShape())) {
171-
strides.push_back(
172-
builder.create<arith::ConstantIndexOp>(loc, runningStride));
173-
if (dim != ShapedType::kDynamic)
174-
runningStride *= dim;
175-
else
176-
runningStride = -1; // not handling dynamic strides.
177-
}
178-
std::reverse(strides.begin(), strides.end());
179-
return strides;
180-
}
181-
182-
private:
183-
OpBuilder &builder;
184-
Location loc;
185-
};
186-
187-
// Replace uses according to predicates automatically.
188-
template <typename OpTy>
189-
void replaceUsesWithPredicate(
190-
OpTy originalValue,
191-
ArrayRef<std::pair<std::function<bool(OpOperand &)>, Value>> replacements,
192-
ConversionPatternRewriter &rewriter) {
193-
194-
for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) {
195-
for (const auto &[predicate, newValue] : replacements) {
196-
if (predicate(use)) {
197-
use.set(newValue);
198-
break;
199-
}
200-
}
201-
}
202-
}
203-
204116
//===----------------------------------------------------------------------===//
205117
// Conversion patterns
206118
//===----------------------------------------------------------------------===//
@@ -355,127 +267,6 @@ struct ConvertGPULaunchFuncOp : OpConversionPattern<gpu::LaunchFuncOp> {
355267
}
356268
};
357269

358-
//===----------------------------------------------------------------------===//
359-
// AllocOp conversion pattern
360-
//===----------------------------------------------------------------------===//
361-
template <typename AllocOp>
362-
struct ConvertAllocOp : OpConversionPattern<AllocOp> {
363-
ConvertAllocOp(MLIRContext *ctx, TypeConverter &typeConverter)
364-
: OpConversionPattern<AllocOp>(ctx), typeConverter(typeConverter) {}
365-
366-
LogicalResult
367-
matchAndRewrite(AllocOp op, typename AllocOp::Adaptor adaptor,
368-
ConversionPatternRewriter &rewriter) const override {
369-
Location loc = op.getLoc();
370-
MemRefType srcType = llvm::cast<MemRefType>(op.getType());
371-
// Only supports memref types with identity layout. Since this mechanism
372-
// requires the usage of memref.ViewOp, which requires the layout to be
373-
// identity.
374-
if (!srcType.getLayout().isIdentity())
375-
op.emitError("only memrefs with identity layout is supported");
376-
377-
auto dstType =
378-
dyn_cast_or_null<MemRefType>(typeConverter.convertType(srcType));
379-
if (!dstType || dstType == srcType)
380-
return failure(); // No need to rewrite.
381-
382-
// Helper class to allocate raw memory.
383-
RawAllocator allocator(rewriter, loc);
384-
385-
// 1. Compute total allocation size.
386-
auto totalBytes = allocator.computeTotalBytes(srcType, op.getMemref());
387-
388-
// 2. Create raw i8 buffer.
389-
MemRefType rawType;
390-
if (std::holds_alternative<int64_t>(totalBytes)) {
391-
// Static size.
392-
SmallVector<int64_t> staticI8Shape;
393-
staticI8Shape.push_back(std::get<int64_t>(totalBytes));
394-
rawType = MemRefType::get(staticI8Shape, rewriter.getI8Type(), {},
395-
srcType.getMemorySpaceAsInt());
396-
} else {
397-
// Dynamic size.
398-
rawType = MemRefType::get({ShapedType::kDynamic}, rewriter.getI8Type(),
399-
{}, srcType.getMemorySpaceAsInt());
400-
}
401-
Value rawAlloc;
402-
403-
if constexpr (std::is_same_v<AllocOp, gpu::AllocOp>) {
404-
rawAlloc =
405-
rewriter
406-
.create<gpu::AllocOp>(
407-
loc, rawType,
408-
op.getAsyncToken() ? op.getAsyncToken().getType() : nullptr,
409-
adaptor.getAsyncDependencies(),
410-
std::holds_alternative<Value>(totalBytes)
411-
? ValueRange{std::get<Value>(totalBytes)}
412-
: ValueRange{},
413-
adaptor.getSymbolOperands(), op.getHostShared())
414-
.getResult(0);
415-
} else {
416-
rawAlloc = rewriter.create<memref::AllocOp>(
417-
loc, rawType,
418-
std::holds_alternative<Value>(totalBytes)
419-
? ValueRange{std::get<Value>(totalBytes)}
420-
: ValueRange{},
421-
op.getSymbolOperands());
422-
}
423-
424-
// 3. Create view for original type.
425-
SmallVector<Value> dynamicSizes =
426-
allocator.getDynamicSizes(srcType, op.getMemref());
427-
// Since we are using memref::ViewOp, only identity strides are supported.
428-
SmallVector<Value> dynamicStrides = allocator.getIdentityStrides(srcType);
429-
Value zeroOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
430-
Value originalView = rewriter.create<memref::ViewOp>(
431-
loc, srcType, rawAlloc, zeroOffset, dynamicSizes);
432-
433-
// 4. Create view for converted type.
434-
Value convertedView = rewriter.create<memref::ViewOp>(
435-
loc, dstType, rawAlloc, zeroOffset, dynamicSizes);
436-
437-
// 5. Replace uses:
438-
// gpu::LaunchFuncOp uses -> Replace the original AllocOp use in
439-
// gpu::LaunchFuncOp with the view of the
440-
// converted type.
441-
//
442-
// DeallocOp uses -> Replace the original AllocOp use in dealloc with
443-
// the new AllocOp.
444-
//
445-
// Other uses-> Replace the original AllocOp use with the view of the
446-
// original type.
447-
448-
SmallVector<OpOperand *> launchFuncUses;
449-
SmallVector<OpOperand *> deallocUses;
450-
SmallVector<OpOperand *> otherUses;
451-
452-
for (OpOperand &use : op->getUses()) {
453-
if (isa<gpu::LaunchFuncOp>(use.getOwner())) {
454-
launchFuncUses.push_back(&use);
455-
} else if (isa<memref::DeallocOp>(use.getOwner()) ||
456-
isa<gpu::DeallocOp>(use.getOwner())) {
457-
deallocUses.push_back(&use);
458-
} else {
459-
otherUses.push_back(&use);
460-
}
461-
}
462-
463-
for (OpOperand *use : launchFuncUses)
464-
use->set(convertedView);
465-
for (OpOperand *use : deallocUses)
466-
use->set(rawAlloc);
467-
for (OpOperand *use : otherUses)
468-
use->set(originalView);
469-
470-
// Erase the original AllocOp.
471-
rewriter.eraseOp(op);
472-
return success();
473-
}
474-
475-
private:
476-
TypeConverter &typeConverter;
477-
};
478-
479270
//===----------------------------------------------------------------------===//
480271
// ArithConstantOp conversion pattern
481272
//===----------------------------------------------------------------------===//
@@ -688,12 +479,10 @@ void mlir::populateImitateUnsupportedTypesTypeConverter(
688479
ValueRange inputs, Location loc) -> Value {
689480
assert(inputs.size() == 1 && "Expected single input");
690481
Type inputType = inputs[0].getType();
691-
if (isa<MemRefType>(resultType) && isa<MemRefType>(inputType)) {
692-
return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
693-
.getResult(0);
694-
}
695-
if ((resultType.isIntOrIndexOrFloat() || isa<VectorType>(resultType)) &&
696-
(inputType.isIntOrIndexOrFloat() || isa<VectorType>(inputType))) {
482+
if ((resultType.isIntOrIndexOrFloat() || isa<VectorType>(resultType) ||
483+
isa<MemRefType>(resultType)) &&
484+
(inputType.isIntOrIndexOrFloat() || isa<VectorType>(inputType) ||
485+
isa<MemRefType>(inputType))) {
697486
return builder.create<arith::BitcastOp>(loc, resultType, inputs[0])
698487
.getResult();
699488
}
@@ -724,8 +513,6 @@ void mlir::populateImitateUnsupportedTypesConversionPatterns(
724513
patterns.add<ConvertCallOp>(ctx, typeConverter, convertedFuncTypes);
725514
patterns.add<ConvertArithConstantOp>(ctx, typeConverter, srcTypes, tgtTypes);
726515
patterns.add<ConvertGPULaunchFuncOp>(ctx);
727-
patterns.add<ConvertAllocOp<gpu::AllocOp>>(ctx, typeConverter);
728-
patterns.add<ConvertAllocOp<memref::AllocOp>>(ctx, typeConverter);
729516
}
730517

731518
//===----------------------------------------------------------------------===//
@@ -744,8 +531,11 @@ void mlir::configureImitateUnsupportedTypesLegality(
744531
return true;
745532
});
746533

747-
target.addDynamicallyLegalDialect<gpu::GPUDialect>(
748-
[&](Operation *op) { return typeConverter.isLegal(op); });
534+
target.addDynamicallyLegalDialect<gpu::GPUDialect>([&](Operation *op) {
535+
if (op->getParentOfType<gpu::GPUModuleOp>())
536+
return typeConverter.isLegal(op);
537+
return true;
538+
});
749539

750540
target.addDynamicallyLegalDialect<func::FuncDialect>([&](Operation *op) {
751541
if (op->getParentOfType<gpu::GPUModuleOp>())
@@ -755,7 +545,6 @@ void mlir::configureImitateUnsupportedTypesLegality(
755545
});
756546

757547
target.addLegalOp<gpu::GPUModuleOp>();
758-
target.addLegalOp<UnrealizedConversionCastOp>();
759548
// Manually mark arithmetic-performing vector instructions.
760549
target.addLegalOp<vector::ContractionOp, vector::ReductionOp,
761550
vector::MultiDimReductionOp, vector::FMAOp,
@@ -767,6 +556,8 @@ void mlir::configureImitateUnsupportedTypesLegality(
767556
target.addDynamicallyLegalOp<gpu::GPUFuncOp>([&](gpu::GPUFuncOp op) {
768557
return typeConverter.isSignatureLegal(op.getFunctionType());
769558
});
559+
target.addDynamicallyLegalOp<gpu::LaunchFuncOp>(
560+
[&](gpu::LaunchFuncOp op) { return typeConverter.isLegal(op); });
770561
// Only convert functions and function calls in gpu.module
771562
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
772563
if (op->getParentOfType<gpu::GPUModuleOp>())
@@ -779,22 +570,8 @@ void mlir::configureImitateUnsupportedTypesLegality(
779570
return true;
780571
});
781572

782-
// Only convert alloc ops in gpu.module or in host functions and has a use
783-
// in LaunchFunc
784-
target.addDynamicallyLegalOp<memref::AllocOp>([&](memref::AllocOp op) {
785-
if (op->getParentOfType<gpu::GPUModuleOp>())
786-
return typeConverter.isLegal(op.getType());
787-
else {
788-
for (auto user : op->getUsers()) {
789-
if (isa<gpu::LaunchFuncOp>(user))
790-
return typeConverter.isLegal(op.getType());
791-
}
792-
}
793-
return true;
794-
});
795-
796-
// Mark unknown ops that are inside gpu.module, and one of its's operand is a
797-
// memref type as dynamically legal.
573+
// Mark unknown ops that are inside gpu.module, and one of its operands is
574+
// a memref type as dynamically legal.
798575
target.markUnknownOpDynamicallyLegal([&typeConverter](Operation *op) -> bool {
799576
// Check if the operation is inside a gpu.module.
800577
if (op->getParentOfType<gpu::GPUModuleOp>()) {
@@ -899,21 +676,6 @@ struct GpuImitateUnsupportedTypesPass
899676
// Apply the conversion.
900677
if (failed(applyPartialConversion(op, target, std::move(patterns))))
901678
return signalPassFailure();
902-
903-
// Post-conversion validation: check for any remaining
904-
// unrealized_conversion_cast.
905-
op->walk([&](UnrealizedConversionCastOp op) {
906-
// Check if the cast is from a source type to a target type.
907-
for (auto [sourceType, targetType] :
908-
llvm::zip_equal(sourceTypes, targetTypes)) {
909-
if (getElementTypeOrSelf(op.getOperand(0).getType()) == sourceType &&
910-
getElementTypeOrSelf(op.getResult(0).getType()) == targetType) {
911-
op->emitError("unresolved unrealized_conversion_cast left in IR "
912-
"after conversion");
913-
return signalPassFailure();
914-
}
915-
}
916-
});
917679
}
918680
};
919681

0 commit comments

Comments
 (0)