File tree Expand file tree Collapse file tree 2 files changed +23
-0
lines changed Expand file tree Collapse file tree 2 files changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -61,6 +61,10 @@ class RankedTensorType {
6161 return ElementType () == other.ElementType () && Layout () == other.Layout ();
6262 }
6363
64+ bool operator !=(const RankedTensorType& other) const {
65+ return !(*this == other);
66+ }
67+
6468 ElementType ElementType () const { return element_type_; }
6569
6670 const Layout& Layout () const { return layout_; }
@@ -105,6 +109,16 @@ class Tensor : public internal::NonOwnedHandle<LiteRtTensor> {
105109 }
106110 }
107111
112+ bool HasType (const RankedTensorType& type) const {
113+ auto t = RankedTensorType ();
114+ return t && *t == type;
115+ }
116+
117+ bool HasType (const LiteRtRankedTensorType& type) const {
118+ auto t = RankedTensorType ();
119+ return t && *t == ::litert::RankedTensorType (type);
120+ }
121+
108122 LiteRtTensorTypeId TypeId () const {
109123 LiteRtTensorTypeId type_id;
110124 internal::AssertOk (LiteRtGetTensorTypeId, Get (), &type_id);
Original file line number Diff line number Diff line change @@ -209,6 +209,15 @@ class TensorBuffer
209209 return RankedTensorType (tensor_type);
210210 }
211211
212+ bool HasType (const RankedTensorType& type) const {
213+ auto t = TensorType ();
214+ return t && *t == type;
215+ }
216+
217+ bool HasType (const LiteRtRankedTensorType& type) const {
218+ auto t = TensorType ();
219+ return t && *t == ::litert::RankedTensorType (type);
220+ }
212221 // Returns the size of the underlying H/W tensor buffer. This size can be
213222 // different to the PackedSize() if there is stride and padding exists.
214223 Expected<size_t > Size () const {
You can’t perform that action at this time.
0 commit comments