Skip to content

Commit 0a0a1b4

Browse files
[MLIR][Torch] Resolve styling issues related to aten zeros/ones op
#464 (comment) Signed-Off By: Vivek Khandelwal <[email protected]>
1 parent f34eb66 commit 0a0a1b4

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3329,41 +3329,43 @@ struct ConvertAtenOnesZerosOp : ConversionPattern {
33293329
return failure();
33303330
Location loc = op->getLoc();
33313331

3332-
SmallVector<Value, 3> opArguments;
3332+
Value size, layout, pin_memory;
33333333
int64_t elementValue;
33343334

33353335
if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
3336-
opArguments.insert(opArguments.end(),
3337-
{onesOp.size(), onesOp.layout(), onesOp.pin_memory()});
3336+
size = onesOp.size();
3337+
layout = onesOp.layout();
3338+
pin_memory = onesOp.pin_memory();
33383339
elementValue = 1;
33393340
} else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
3340-
opArguments.insert(opArguments.end(), {zerosOp.size(), zerosOp.layout(),
3341-
zerosOp.pin_memory()});
3341+
size = zerosOp.size();
3342+
layout = zerosOp.layout();
3343+
pin_memory = zerosOp.pin_memory();
33423344
elementValue = 0;
33433345
}
33443346

33453347
// We ignore device, but add simple asserts for unimplemented kwargs
3346-
if (!opArguments[1].getType().isa<Torch::NoneType>())
3348+
if (!layout.getType().isa<Torch::NoneType>())
33473349
return rewriter.notifyMatchFailure(op,
33483350
"only default layout is supported");
33493351

33503352
bool pinMemory = false;
3351-
if (!opArguments[2].getType().isa<Torch::NoneType>() &&
3352-
!matchPattern(opArguments[2], m_TorchConstantBool(&pinMemory))) {
3353+
if (!pin_memory.getType().isa<Torch::NoneType>() &&
3354+
!matchPattern(pin_memory, m_TorchConstantBool(&pinMemory))) {
33533355
return rewriter.notifyMatchFailure(
33543356
op, "pin_memory must be constant bool or None");
33553357
}
33563358
if (pinMemory)
33573359
return rewriter.notifyMatchFailure(op, "memory pinning not supported");
33583360

3359-
SmallVector<Value> size, sizeIndex;
3360-
if (!getListConstructElements(opArguments[0], size)) {
3361+
SmallVector<Value> sizes, sizeIndex;
3362+
if (!getListConstructElements(size, sizes)) {
33613363
return rewriter.notifyMatchFailure(
33623364
op, "size must be created by ListConstruct");
33633365
}
3364-
size = getTypeConvertedValues(rewriter, loc, getTypeConverter(), size);
3365-
for (size_t i = 0; i < size.size(); i++)
3366-
sizeIndex.push_back(castIntToIndex(rewriter, loc, size[i]));
3366+
sizes = getTypeConvertedValues(rewriter, loc, getTypeConverter(), sizes);
3367+
for (size_t i = 0; i < sizes.size(); i++)
3368+
sizeIndex.push_back(castIntToIndex(rewriter, loc, sizes[i]));
33673369

33683370
RankedTensorType newResultType =
33693371
getTypeConverter()

0 commit comments

Comments
 (0)