@@ -18,9 +18,8 @@ namespace vkcompute {
1818
1919vkapi::ShaderInfo get_noop_shader (ComputeGraph& graph, const ValueRef packed) {
2020 std::string noop_shader_name (" no_op" );
21- vTensorPtr t_packed = graph.get_tensor (packed);
22- add_dtype_suffix (noop_shader_name, *t_packed);
23- add_storage_type_suffix (noop_shader_name, *t_packed);
21+ add_dtype_suffix (noop_shader_name, graph.dtype_of (packed));
22+ add_storage_type_suffix (noop_shader_name, graph.storage_type_of (packed));
2423 return VK_KERNEL_FROM_STR (noop_shader_name);
2524}
2625
@@ -48,13 +47,13 @@ PrepackNode::PrepackNode(
4847}
4948
5049api::StagingBuffer PrepackNode::create_staging_buffer (ComputeGraph* graph) {
51- vTensorPtr packed = graph->get_tensor (packed_);
52-
53- // If no TensorRef is provided, create a staging buffer of zeros according to
54- // the vkapi::vTensor metadata.
50+ // If no TensorRef is provided, create a staging buffer of zeros based on the
51+ // Tensor metadata.
5552 if (graph->val_is_none (tref_)) {
56- size_t numel = utils::multiply_integers (packed->sizes ());
57- api::StagingBuffer staging (graph->context (), packed->dtype (), numel);
53+ const std::vector<int64_t > packed_sizes = graph->sizes_of (packed_);
54+ size_t numel = utils::multiply_integers (packed_sizes);
55+ api::StagingBuffer staging (
56+ graph->context (), graph->dtype_of (packed_), numel);
5857 staging.set_staging_zeros ();
5958 return staging;
6059 }
@@ -80,7 +79,6 @@ void PrepackNode::encode(ComputeGraph* graph) {
8079
8180 context->check_device_capabilities (shader_);
8281
83- vTensorPtr packed = graph->get_tensor (packed_);
8482 api::StagingBuffer staging = create_staging_buffer (graph);
8583
8684 std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock ();
@@ -101,8 +99,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
10199 shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
102100
103101 uint32_t idx = 0 ;
104- bind_tensor_to_descriptor_set (
105- *packed ,
102+ graph-> bind_tensor_to_descriptor_set (
103+ packed_ ,
106104 pipeline_barrier,
107105 vkapi::MemoryAccessType::WRITE,
108106 descriptor_set,
@@ -128,8 +126,8 @@ void PrepackNode::encode(ComputeGraph* graph) {
128126 vkapi::DescriptorSet descriptor_set = context->get_descriptor_set (
129127 noop_shader_, utils::WorkgroupSize (1 , 1 , 1 ));
130128
131- bind_tensor_to_descriptor_set (
132- *packed ,
129+ graph-> bind_tensor_to_descriptor_set (
130+ packed_ ,
133131 pipeline_barrier,
134132 vkapi::MemoryAccessType::READ,
135133 descriptor_set,
0 commit comments