Skip to content

Commit 46925eb

Browse files
authored
Address small downstream build issues (#4289)
Two issues seemed to block an integrate downstream with IREE: iree-org/iree#21628 1. Addresses an unused variable for release builds (and fixes a typo in the check 2->3). 2. Adds missing custom build methods for two `TMTensor` ops. (I'm not sure why this hasn't caused a failure in the past). Signed-off-by: zjgarvey <[email protected]>
1 parent 8809167 commit 46925eb

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def TMTensor_ScanOp : TMTensor_Op<"scan",
6262

6363
let builders = [
6464
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
65-
CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive)>
65+
CArg<"int64_t", "0">:$dimension, CArg<"bool", "true">:$inclusive), [{
66+
build($_builder, $_state, TypeRange(outputs), inputs, outputs, dimension, inclusive);
67+
}]>
6668
];
6769

6870
let results = (outs Variadic<AnyRankedTensor>:$results);
@@ -267,7 +269,9 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
267269
);
268270

269271
let builders = [
270-
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs)>
272+
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs), [{
273+
build($_builder, $_state, TypeRange(outputs), inputs, outputs);
274+
}]>
271275
];
272276

273277
let results = (outs Variadic<AnyRankedTensor>:$result);

lib/Conversion/TorchToLinalg/TensorConstructors.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,8 @@ class ConvertAtenReplicationPad3dOp
284284
Location loc = op->getLoc();
285285
Value input = adaptor.getSelf();
286286
auto inputType = llvm::cast<RankedTensorType>(input.getType());
287-
int64_t inputRank = inputType.getRank();
288-
assert(inputRank >= 2 && "Not enough input dimensions");
287+
[[maybe_unused]] int64_t inputRank = inputType.getRank();
288+
assert(inputRank >= 3 && "Not enough input dimensions");
289289

290290
SmallVector<int64_t> padInts;
291291
if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts)))

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,9 @@ static Value createTMTensorScanOp(
236236
int64_t dim, bool inclusive,
237237
function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuild) {
238238
auto inputType = cast<RankedTensorType>(input.getType());
239-
auto accType = cast<RankedTensorType>(accumulator.getType());
240239
Type elementType = inputType.getElementType();
241240
auto scanOp = b.create<TMTensor::ScanOp>(
242-
loc, TypeRange{inputType, accType}, input,
243-
ValueRange{output, accumulator}, b.getI64IntegerAttr(dim),
244-
b.getBoolAttr(inclusive));
241+
loc, ValueRange{input}, ValueRange{output, accumulator}, dim, inclusive);
245242

246243
Region &scanOpRegion = scanOp.getRegion();
247244
auto &scanOpBlock = scanOpRegion.emplaceBlock();

0 commit comments

Comments
 (0)