@@ -7087,9 +7087,16 @@ class DecomposeAtenEmptyLikeOp : public OpRewritePattern<AtenEmptyLikeOp> {
7087
7087
Torch::ListType::get (Torch::IntType::get (op.getContext ()));
7088
7088
Value sizeList =
7089
7089
rewriter.create <AtenSizeOp>(op.getLoc (), sizeListType, op.getSelf ());
7090
+
7091
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
7092
+ if (failed (dtype)) {
7093
+ return rewriter.notifyMatchFailure (
7094
+ op, " could not determine dtype from the op." );
7095
+ }
7096
+
7090
7097
rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
7091
- op, op.getType (), sizeList, op.getDtype (), op.getLayout (),
7092
- op.getDevice (), op. getPinMemory (), op.getMemoryFormat ());
7098
+ op, op.getType (), sizeList, *dtype, op.getLayout (), op.getDevice (),
7099
+ op.getPinMemory (), op.getMemoryFormat ());
7093
7100
return success ();
7094
7101
}
7095
7102
};
@@ -7838,18 +7845,13 @@ class DecomposeAtenNewEmptyOp : public OpRewritePattern<AtenNewEmptyOp> {
7838
7845
LogicalResult matchAndRewrite (AtenNewEmptyOp op,
7839
7846
PatternRewriter &rewriter) const override {
7840
7847
Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
7841
- Value dtype = op.getDtype ();
7842
- if (isa<Torch::NoneType>(dtype.getType ())) {
7843
- BaseTensorType tensorType = cast<BaseTensorType>(op.getSelf ().getType ());
7844
- if (!tensorType.hasDtype ()) {
7845
- return rewriter.notifyMatchFailure (
7846
- op, " expected input tensor to have a dtype" );
7847
- }
7848
- dtype =
7849
- getDtypeIntValueForType (rewriter, op.getLoc (), tensorType.getDtype ());
7848
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
7849
+ if (failed (dtype)) {
7850
+ return rewriter.notifyMatchFailure (
7851
+ op, " could not determine dtype from the op." );
7850
7852
}
7851
7853
rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
7852
- op, op.getType (), op.getSize (), dtype, op.getLayout (), op.getDevice (),
7854
+ op, op.getType (), op.getSize (), * dtype, op.getLayout (), op.getDevice (),
7853
7855
op.getPinMemory (), /* memoryFormat=*/ noneVal);
7854
7856
return success ();
7855
7857
}
@@ -9257,12 +9259,12 @@ class DecomposeAtenRandnGeneratorOp
9257
9259
Location loc = op.getLoc ();
9258
9260
auto resultType = cast<BaseTensorType>(op.getType ());
9259
9261
9260
- if (!resultType.hasDtype ()) {
9262
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9263
+ if (failed (dtype)) {
9261
9264
return rewriter.notifyMatchFailure (
9262
- op, " expected result type to have a dtype " );
9265
+ op, " could not determine dtype from the op. " );
9263
9266
}
9264
9267
9265
- Value dtype = getDtypeIntValueForType (rewriter, loc, resultType.getDtype ());
9266
9268
Value none = rewriter.create <ConstantNoneOp>(loc);
9267
9269
Value low = rewriter.create <Torch::ConstantFloatOp>(
9268
9270
loc, rewriter.getF64FloatAttr ((double )0.0 ));
@@ -9274,12 +9276,12 @@ class DecomposeAtenRandnGeneratorOp
9274
9276
loc, rewriter.getF64FloatAttr ((double )(2.0 * 3.14159 )));
9275
9277
9276
9278
Value emptyTensorA = rewriter.create <AtenEmptyMemoryFormatOp>(
9277
- loc, resultType, op.getSize (), /* dtype=*/ dtype,
9279
+ loc, resultType, op.getSize (), /* dtype=*/ * dtype,
9278
9280
/* layout=*/ op.getLayout (),
9279
9281
/* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
9280
9282
/* memory_format=*/ none);
9281
9283
Value emptyTensorB = rewriter.create <AtenEmptyMemoryFormatOp>(
9282
- loc, resultType, op.getSize (), /* dtype=*/ dtype,
9284
+ loc, resultType, op.getSize (), /* dtype=*/ * dtype,
9283
9285
/* layout=*/ op.getLayout (),
9284
9286
/* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
9285
9287
/* memory_format=*/ none);
@@ -9377,8 +9379,13 @@ class DecomposeAtenRandOp : public OpRewritePattern<AtenRandOp> {
9377
9379
loc, rewriter.getF64FloatAttr ((double )0.0 ));
9378
9380
Value high = rewriter.create <Torch::ConstantFloatOp>(
9379
9381
loc, rewriter.getF64FloatAttr ((double )1.0 ));
9382
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9383
+ if (failed (dtype)) {
9384
+ return rewriter.notifyMatchFailure (
9385
+ op, " could not determine dtype from the op." );
9386
+ }
9380
9387
Value emptyTensor = rewriter.create <AtenEmptyMemoryFormatOp>(
9381
- loc, resultType, op.getSize (), /* dtype=*/ op. getDtype () ,
9388
+ loc, resultType, op.getSize (), /* dtype=*/ *dtype ,
9382
9389
/* layout=*/ op.getLayout (),
9383
9390
/* device=*/ op.getDevice (), /* pin_memory=*/ op.getPinMemory (),
9384
9391
/* memory_format=*/ noneVal);
@@ -9536,9 +9543,14 @@ class DecomposeAtenEmptyStridedOp
9536
9543
9537
9544
Value noneVal = rewriter.create <ConstantNoneOp>(op.getLoc ());
9538
9545
9546
+ FailureOr<Value> dtype = getDtypeFromOp (rewriter, op);
9547
+ if (failed (dtype)) {
9548
+ return rewriter.notifyMatchFailure (
9549
+ op, " could not determine dtype from the op." );
9550
+ }
9539
9551
rewriter.replaceOpWithNewOp <AtenEmptyMemoryFormatOp>(
9540
- op, op.getType (), op.getSize (), op.getDtype (), op.getLayout (),
9541
- op.getDevice (), op. getPinMemory (), /* memoryFormat=*/ noneVal);
9552
+ op, op.getType (), op.getSize (), *dtype, op.getLayout (), op.getDevice (),
9553
+ op.getPinMemory (), /* memoryFormat=*/ noneVal);
9542
9554
return success ();
9543
9555
}
9544
9556
};
0 commit comments