Commit 2c6915e

Update base for Update on "[ET-VK] New implementation of cat operator"
## Changes

* Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator
* Introduce `Concat.cpp` to replace `Cat.cpp`
* Fix a bug with channels-packed buffer tensors where input data would be copied incorrectly when multiple dims have a stride of 1

## Motivation

> * Introduce `concat_texture.glsl` and `concat_buffer.glsl` to implement the `torch.cat` operator
> * Introduce `Concat.cpp` to replace `Cat.cpp`

The existing implementation of `torch.cat` uses the `copy_channel_offset` shaders. However, these shaders have a critical bug where the output tensor is passed in separately with different access types, i.e.

```
graph.execute_nodes().emplace_back(new DispatchNode(
    graph,
    VK_KERNEL_FROM_STR(kernel_name),
    global_size,
    local_size,
    // Inputs and Outputs
    {
        {out, vkapi::kWrite},
        {out, vkapi::kRead},
        {in, vkapi::kRead},
    },
```

This creates many validation layer errors because the memory barriers for the resource cannot be formed properly. The shader essentially relies on undefined behaviour to work correctly. The result is that the `cat` operator produces incorrect results on many platforms.

Rather than fix the `copy_offset` shaders, I decided to just introduce new shaders to perform the concat operation. The new implementation handles both buffer and texture inputs and is agnostic to memory layout.

Differential Revision: [D76305343](https://our.internmc.facebook.com/intern/diff/D76305343/)

[ghstack-poisoned]
1 parent d43de86 commit 2c6915e

4 files changed, 6 additions and 8 deletions

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 0 additions & 2 deletions
@@ -16,8 +16,6 @@
 
 #include <executorch/backends/vulkan/runtime/utils/StorageUtils.h>
 
-#include <iostream>
-
 namespace vkcompute {
 namespace api {
 

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 2 additions & 2 deletions
@@ -253,13 +253,13 @@ ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) {
  * e.g. 0x11021, 1 -> ivec4(1, 2, 0, 1)
  */
 #define unhash_axis_map(hash) \
-  ivec4(hash & 0xf, (hash >> 4) & 0xf, (hash >> 8 & 0xf), (hash >> 12 & 0xf))
+  (ivec4(hash & 0xf, (hash >> 4) & 0xf, (hash >> 8 & 0xf), (hash >> 12 & 0xf)))
 
 /*
  *
  */
 #define unhash_dim_order(hash) \
-  ivec4(hash & 0xf, (hash >> 4) & 0xf, (hash >> 8 & 0xf), (hash >> 12 & 0xf))
+  (ivec4(hash & 0xf, (hash >> 4) & 0xf, (hash >> 8 & 0xf), (hash >> 12 & 0xf)))
 
 #define unhash_packed_dim(hash) int(hash >> 16 & 0xf)
 
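
The nibble unpacking done by `unhash_axis_map` and `unhash_packed_dim` can be checked on the host. The following is a minimal C++ sketch of the same arithmetic, not code from this commit: `IVec4` is a hypothetical stand-in for GLSL's `ivec4`, and the macros are rewritten as functions purely for illustration. It reproduces the `0x11021` example from the header comment.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical host-side stand-in for GLSL's ivec4; not an ExecuTorch type.
using IVec4 = std::array<int, 4>;

// Same arithmetic as the unhash_axis_map macro: nibble i of the hash
// (counting from the least significant nibble) becomes element i.
inline IVec4 unhash_axis_map(std::uint32_t hash) {
  return IVec4{
      static_cast<int>(hash & 0xf),
      static_cast<int>((hash >> 4) & 0xf),
      static_cast<int>((hash >> 8) & 0xf),
      static_cast<int>((hash >> 12) & 0xf)};
}

// Same arithmetic as unhash_packed_dim: the fifth nibble holds the packed dim.
inline int unhash_packed_dim(std::uint32_t hash) {
  return static_cast<int>((hash >> 16) & 0xf);
}

// Checks the example given in the header comment for the hash 0x11021.
inline void check_header_example() {
  const IVec4 axis_map = unhash_axis_map(0x11021);
  assert(axis_map[0] == 1);
  assert(axis_map[1] == 2);
  assert(axis_map[2] == 0);
  assert(axis_map[3] == 1);
  assert(unhash_packed_dim(0x11021) == 1);
}
```

Wrapping the macro body in an outer pair of parentheses, as the diff does, is presumably a defensive measure so that the expansion is always treated as a single expression at the call site. Note that the macro argument `hash` itself is still not parenthesized, so callers would need to pass a simple expression.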

backends/vulkan/runtime/vk_api/Descriptor.cpp

Lines changed: 2 additions & 2 deletions
@@ -32,8 +32,8 @@ BufferBindInfo::BufferBindInfo(
 
 BufferBindInfo::BufferBindInfo(
     const VulkanBuffer& buffer_p,
-    const uint32_t offset_p,
-    const uint32_t range_p)
+    const size_t offset_p,
+    const size_t range_p)
     : handle(buffer_p.handle()),
       offset(buffer_p.mem_offset() + offset_p),
       range(range_p) {
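
The commit does not state why the parameters were widened from `uint32_t` to `size_t`. One plausible reason is that a 32-bit parameter silently truncates any offset or range of 4 GiB or more before it is added to the buffer's memory offset. The sketch below illustrates that effect under stated assumptions: both helper functions are hypothetical and not part of ExecuTorch, and `std::uint64_t` is used in place of `size_t` so the behaviour is the same on every platform (Vulkan's `VkDeviceSize` is also a 64-bit type).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper modelling the old signature: the offset parameter is
// 32 bits wide, so the caller's value is truncated before the addition.
inline std::uint64_t bind_offset_narrow(
    std::uint64_t mem_offset,
    std::uint32_t offset_p) {
  return mem_offset + offset_p;
}

// Hypothetical helper modelling the widened signature: the full 64-bit
// offset reaches the addition intact.
inline std::uint64_t bind_offset_wide(
    std::uint64_t mem_offset,
    std::uint64_t offset_p) {
  return mem_offset + offset_p;
}

// An illustrative offset just past the 4 GiB boundary.
constexpr std::uint64_t kRequestedOffset = (std::uint64_t{1} << 32) + 256;

inline void check_truncation() {
  // The explicit cast makes visible the narrowing that an implicit
  // conversion would perform at the call site.
  const std::uint64_t narrow =
      bind_offset_narrow(0, static_cast<std::uint32_t>(kRequestedOffset));
  const std::uint64_t wide = bind_offset_wide(0, kRequestedOffset);
  assert(narrow == 256);             // upper 32 bits were lost
  assert(wide == kRequestedOffset);  // full offset preserved
}
```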

backends/vulkan/runtime/vk_api/Descriptor.h

Lines changed: 2 additions & 2 deletions
@@ -36,8 +36,8 @@ struct BufferBindInfo final {
   BufferBindInfo(const VulkanBuffer& buffer_p, const uint32_t offset_p = 0u);
   BufferBindInfo(
       const VulkanBuffer& buffer_p,
-      const uint32_t offset_p,
-      const uint32_t range_p);
+      const size_t offset_p,
+      const size_t range_p);
 };
 
 struct ParamsBindList final {
