@@ -3329,41 +3329,43 @@ struct ConvertAtenOnesZerosOp : ConversionPattern {
33293329 return failure ();
33303330 Location loc = op->getLoc ();
33313331
3332- SmallVector< Value, 3 > opArguments ;
3332+ Value size, layout, pin_memory ;
33333333 int64_t elementValue;
33343334
33353335 if (AtenOnesOp onesOp = dyn_cast<AtenOnesOp>(op)) {
3336- opArguments.insert (opArguments.end (),
3337- {onesOp.size (), onesOp.layout (), onesOp.pin_memory ()});
3336+ size = onesOp.size ();
3337+ layout = onesOp.layout ();
3338+ pin_memory = onesOp.pin_memory ();
33383339 elementValue = 1 ;
33393340 } else if (AtenZerosOp zerosOp = dyn_cast<AtenZerosOp>(op)) {
3340- opArguments.insert (opArguments.end (), {zerosOp.size (), zerosOp.layout (),
3341- zerosOp.pin_memory ()});
3341+ size = zerosOp.size ();
3342+ layout = zerosOp.layout ();
3343+ pin_memory = zerosOp.pin_memory ();
33423344 elementValue = 0 ;
33433345 }
33443346
33453347 // We ignore device, but add simple asserts for unimplemented kwargs
3346- if (!opArguments[ 1 ] .getType ().isa <Torch::NoneType>())
3348+ if (!layout .getType ().isa <Torch::NoneType>())
33473349 return rewriter.notifyMatchFailure (op,
33483350 " only default layout is supported" );
33493351
33503352 bool pinMemory = false ;
3351- if (!opArguments[ 2 ] .getType ().isa <Torch::NoneType>() &&
3352- !matchPattern (opArguments[ 2 ] , m_TorchConstantBool (&pinMemory))) {
3353+ if (!pin_memory .getType ().isa <Torch::NoneType>() &&
3354+ !matchPattern (pin_memory , m_TorchConstantBool (&pinMemory))) {
33533355 return rewriter.notifyMatchFailure (
33543356 op, " pin_memory must be constant bool or None" );
33553357 }
33563358 if (pinMemory)
33573359 return rewriter.notifyMatchFailure (op, " memory pinning not supported" );
33583360
3359- SmallVector<Value> size , sizeIndex;
3360- if (!getListConstructElements (opArguments[ 0 ], size )) {
3361+ SmallVector<Value> sizes , sizeIndex;
3362+ if (!getListConstructElements (size, sizes )) {
33613363 return rewriter.notifyMatchFailure (
33623364 op, " size must be created by ListConstruct" );
33633365 }
3364- size = getTypeConvertedValues (rewriter, loc, getTypeConverter (), size );
3365- for (size_t i = 0 ; i < size .size (); i++)
3366- sizeIndex.push_back (castIntToIndex (rewriter, loc, size [i]));
3366+ sizes = getTypeConvertedValues (rewriter, loc, getTypeConverter (), sizes );
3367+ for (size_t i = 0 ; i < sizes .size (); i++)
3368+ sizeIndex.push_back (castIntToIndex (rewriter, loc, sizes [i]));
33673369
33683370 RankedTensorType newResultType =
33693371 getTypeConverter ()
0 commit comments