Skip to content

Commit 80bd093

Browse files
afalkenberg1vivekkhandelwal1
authored andcommitted
Added tensorResultTypeAtIndex to Patterns.h
Need this for LayerNorm
1 parent 9adad9b commit 80bd093

File tree

1 file changed

+10
-0
lines changed
  • include/torch-mlir/Conversion/TorchOnnxToTorch

1 file changed

+10
-0
lines changed

include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ struct OpBinder {
9595
return success();
9696
}
9797

98+
ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) {
99+
if (idx >= op->getNumResults())
100+
return failure();
101+
auto t = toValidTensorType(op->getResult(idx).getType());
102+
if (!t)
103+
return failure();
104+
typeIdx = t;
105+
return success();
106+
}
107+
98108
// Attribute accessors.
99109
ParseResult s64BoolAttr(bool &value, StringRef nameSuffix,
100110
bool defaultValue = false) {

0 commit comments

Comments
 (0)