@@ -239,13 +239,14 @@ ValueRef ComputeGraph::add_tensor(
239239 const vkapi::ScalarType dtype,
240240 const utils::StorageType storage_type,
241241 const utils::GPUMemoryLayout memory_layout,
242- const int64_t shared_object_idx) {
242+ const int64_t shared_object_idx,
243+ const utils::AxisMapLayout axis_map_layout) {
243244 bool allocate_memory = shared_object_idx < 0 ;
244245
245246 ValueRef idx (static_cast <int >(values_.size ()));
246247 check_no_active_value_ptrs ();
247248 values_.emplace_back (api::vTensor (
248- context (), sizes, dtype, storage_type, memory_layout, allocate_memory));
249+ context (), sizes, dtype, storage_type, memory_layout, allocate_memory, axis_map_layout ));
249250
250251 if (!allocate_memory) {
251252 get_shared_object (shared_object_idx).add_user (this , idx);
@@ -257,44 +258,50 @@ ValueRef ComputeGraph::add_tensor(
257258 const std::vector<int64_t >& sizes,
258259 const vkapi::ScalarType dtype,
259260 const utils::StorageType storage_type,
260- const int64_t shared_object_idx) {
261+ const int64_t shared_object_idx,
262+ const utils::AxisMapLayout axis_map_layout) {
261263 return add_tensor (
262264 sizes,
263265 dtype,
264266 storage_type,
265267 suggested_memory_layout (sizes),
266- shared_object_idx);
268+ shared_object_idx,
269+ axis_map_layout);
267270}
268271
269272ValueRef ComputeGraph::add_tensor (
270273 const std::vector<int64_t >& sizes,
271274 const vkapi::ScalarType dtype,
272275 const utils::GPUMemoryLayout memory_layout,
273- const int64_t shared_object_idx) {
276+ const int64_t shared_object_idx,
277+ const utils::AxisMapLayout axis_map_layout) {
274278 return add_tensor (
275- sizes, dtype, suggested_storage_type (), memory_layout, shared_object_idx);
279+ sizes, dtype, suggested_storage_type (), memory_layout, shared_object_idx, axis_map_layout );
276280}
277281
278282ValueRef ComputeGraph::add_tensor_like (
279283 const ValueRef idx,
280284 const utils::StorageType storage_type,
281- const utils::GPUMemoryLayout memory_layout) {
282- return add_tensor (sizes_of (idx), dtype_of (idx), storage_type, memory_layout);
285+ const utils::GPUMemoryLayout memory_layout,
286+ const utils::AxisMapLayout axis_map_layout) {
287+ return add_tensor (sizes_of (idx), dtype_of (idx), storage_type, memory_layout, -1 , axis_map_layout);
283288}
284289
285290ValueRef ComputeGraph::add_tensor_like (
286291 const ValueRef idx,
287- const utils::GPUMemoryLayout memory_layout) {
292+ const utils::GPUMemoryLayout memory_layout,
293+ const utils::AxisMapLayout axis_map_layout) {
288294 return add_tensor (
289- sizes_of (idx), dtype_of (idx), storage_type_of (idx), memory_layout);
295+ sizes_of (idx), dtype_of (idx), storage_type_of (idx), memory_layout, - 1 , axis_map_layout );
290296}
291297
292298ValueRef ComputeGraph::add_tensor (
293299 const std::vector<int64_t >& sizes,
294300 const vkapi::ScalarType dtype,
295- const int64_t shared_object_idx) {
301+ const int64_t shared_object_idx,
302+ const utils::AxisMapLayout axis_map_layout) {
296303 return add_tensor (
297- sizes, dtype, suggested_memory_layout (sizes), shared_object_idx);
304+ sizes, dtype, suggested_memory_layout (sizes), shared_object_idx, axis_map_layout );
298305}
299306
300307ValueRef ComputeGraph::add_tensor (const vkapi::VulkanImage& image) {
0 commit comments