Skip to content

Commit 366baa6

Browse files
authored
[ET-VK][ez] Make get_tensor() API protected (#13168)
## Changes As title; make the `get_tensor()` API protected. ## Motivation See the below diff/PR in the stack. The goal is to encourage operator authors to go through the `ComputeGraph` to access/modify tensors so that the activity can be tracked. Differential Revision: [D79564596](https://our.internmc.facebook.com/intern/diff/D79564596/)
1 parent 5488056 commit 366baa6

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,16 @@ class ComputeGraph final {
248248
return values_.at(idx).is##type_name(); \
249249
}
250250

251-
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(vTensorPtr, tensor, Tensor)
251+
protected:
252+
inline vTensorPtr get_tensor(const ValueRef idx) {
253+
return vTensorPtr(this, idx);
254+
}
255+
256+
public:
257+
inline bool val_is_tensor(const ValueRef idx) const {
258+
return values_.at(idx).isTensor();
259+
}
260+
252261
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(TensorRefPtr, tref, TensorRef)
253262
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(StagingPtr, staging, Staging)
254263
GET_AND_CHECK_VAL_AS_PTR_TYPE_FNS(IntListPtr, int_list, IntList)
@@ -970,6 +979,8 @@ class ComputeGraph final {
970979
friend class SymIntPtr;
971980

972981
friend struct TmpTensor;
982+
friend struct SharedObject;
983+
friend class BlitNode;
973984
};
974985

975986
template <typename T>

0 commit comments

Comments
 (0)