Skip to content

Commit 76e10d1

Browse files
LukeBoyercopybara-github
authored andcommitted
Add convenience method to c++ api for tensor and tensor buffer to check type directly. This alleviates the minor annoyance of checking the expected before comparison and converting a c type.
LiteRT-PiperOrigin-RevId: 775400125
1 parent c8e50b0 commit 76e10d1

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

litert/cc/litert_model.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ class RankedTensorType {
6161
return ElementType() == other.ElementType() && Layout() == other.Layout();
6262
}
6363

64+
bool operator!=(const RankedTensorType& other) const {
65+
return !(*this == other);
66+
}
67+
6468
ElementType ElementType() const { return element_type_; }
6569

6670
const Layout& Layout() const { return layout_; }
@@ -105,6 +109,16 @@ class Tensor : public internal::NonOwnedHandle<LiteRtTensor> {
105109
}
106110
}
107111

112+
bool HasType(const RankedTensorType& type) const {
113+
auto t = RankedTensorType();
114+
return t && *t == type;
115+
}
116+
117+
bool HasType(const LiteRtRankedTensorType& type) const {
118+
auto t = RankedTensorType();
119+
return t && *t == ::litert::RankedTensorType(type);
120+
}
121+
108122
LiteRtTensorTypeId TypeId() const {
109123
LiteRtTensorTypeId type_id;
110124
internal::AssertOk(LiteRtGetTensorTypeId, Get(), &type_id);

litert/cc/litert_tensor_buffer.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,15 @@ class TensorBuffer
209209
return RankedTensorType(tensor_type);
210210
}
211211

212+
bool HasType(const RankedTensorType& type) const {
213+
auto t = TensorType();
214+
return t && *t == type;
215+
}
216+
217+
bool HasType(const LiteRtRankedTensorType& type) const {
218+
auto t = TensorType();
219+
return t && *t == ::litert::RankedTensorType(type);
220+
}
212221
// Returns the size of the underlying H/W tensor buffer. This size can be
213222
// different to the PackedSize() if there is stride and padding exists.
214223
Expected<size_t> Size() const {

0 commit comments

Comments
 (0)