|
20 | 20 | #include "mlir/IR/PatternMatch.h"
|
21 | 21 | #include "mlir/IR/TypeRange.h"
|
22 | 22 | #include "mlir/IR/Value.h"
|
| 23 | +#include "mlir/IR/ValueRange.h" |
23 | 24 | #include "mlir/Transforms/DialectConversion.h"
|
24 | 25 | #include "llvm/Support/FormatVariadic.h"
|
25 | 26 | #include <cstdint>
|
@@ -316,36 +317,26 @@ struct ConvertReinterpretCastOp final
|
316 | 317 |
|
317 | 318 | return ptr;
|
318 | 319 | };
|
319 |
| - auto [strides, offset] = targetMemRefType.getStridesAndOffset(); |
320 |
| - // Value offsetValue = rewriter.create<emitc::ConstantOp>( |
321 |
| - // loc, rewriter.getIndexType(), rewriter.getIndexAttr(offset)); |
322 | 320 |
|
323 | 321 | auto srcPtr = createPointerFromEmitcArray(srcArrayValue);
|
324 |
| - // emitc::PointerType targetPointerType = |
325 |
| - // emitc::PointerType::get(srcArrayValue.getType().getElementType()); |
326 |
| - |
327 |
| - auto dimensions = targetMemRefType.getShape(); |
328 |
| - std::string reinterpretCastName = llvm::formatv( |
329 |
| - "reinterpret_cast<{0}(*)", srcArrayValue.getType().getElementType()); |
330 |
| - std::string dimensionsStr; |
331 |
| - for (auto dim : dimensions) { |
332 |
| - dimensionsStr += llvm::formatv("[{0}]", dim); |
333 |
| - } |
334 |
| - reinterpretCastName += llvm::formatv("{0}>", dimensionsStr); |
335 |
| - reinterpretCastName += ">"; |
336 |
| - |
337 |
| - reinterpretCastName += llvm::formatv("{0}", srcPtr->getResult(0)); |
| 322 | + // 1. Create a TypeAttr for the target type. |
| 323 | + TypeAttr targetTypeAttr = |
| 324 | + TypeAttr::get(emitc::PointerType::get(targetInEmitC)); |
| 325 | + IntegerAttr resty = rewriter.getIndexAttr(0); |
338 | 326 |
|
339 |
| - std::string outputStr = llvm::formatv( |
340 |
| - "{0}(*){1}", srcArrayValue.getType().getElementType(), dimensionsStr); |
341 |
| - auto outputType = emitc::PointerType::get( |
342 |
| - emitc::OpaqueType::get(rewriter.getContext(), outputStr)); |
| 327 | + // 2. Create an ArrayAttr with the TypeAttr. This will be the |
| 328 | + // templateArgsAttr. |
| 329 | + ArrayAttr templateArgsAttr = rewriter.getArrayAttr({targetTypeAttr}); |
343 | 330 |
|
344 |
| - emitc::ConstantOp reinterpretOp = rewriter.create<emitc::ConstantOp>( |
345 |
| - loc, outputType, |
346 |
| - emitc::OpaqueAttr::get(rewriter.getContext(), reinterpretCastName)); |
| 331 | + auto reinterpretCastCall = rewriter.create<emitc::CallOpaqueOp>( |
| 332 | + loc, |
| 333 | + /*result types=*/TypeRange{emitc::PointerType::get(targetInEmitC)}, |
| 334 | + /*callee=*/"reinterpret_cast", |
| 335 | + /*args*/ rewriter.getArrayAttr({resty}), |
| 336 | + /*template_args=*/templateArgsAttr, |
| 337 | + /*operands=*/ValueRange{srcPtr.getResult()}); |
347 | 338 |
|
348 |
| - rewriter.replaceOp(castOp, reinterpretOp.getResult()); |
| 339 | + rewriter.replaceOp(castOp, reinterpretCastCall.getResults()); |
349 | 340 | return success();
|
350 | 341 | }
|
351 | 342 | };
|
|
0 commit comments