1919
2020layout (std430) buffer ;
2121
22- ${layout_declare_tensor(0 , "w", "t_out", DTYPE, STORAGE)}
23- ${layout_declare_tensor(1 , "r", "t_in", DTYPE, STORAGE)}
24- ${layout_declare_tensor(2 , "r", "t_other", DTYPE, STORAGE)}
25- ${layout_declare_ubo(3 , "ivec4 ", "out_sizes")}
26- ${layout_declare_ubo(4 , "ivec4 ", "in_sizes")}
27- ${layout_declare_ubo(5 , "ivec4 ", "other_sizes")}
28- ${layout_declare_ubo(6 , "ivec2 ", "broadcast_params")}
29- ${layout_declare_ubo(7 , "float ", "alpha")}
22+ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
23+ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
24+ ${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
25+ ${layout_declare_ubo(B, "ivec4 ", "out_sizes")}
26+ ${layout_declare_ubo(B, "ivec4 ", "out_axis_map")}
27+ ${layout_declare_ubo(B, "ivec4 ", "in_sizes")}
28+ ${layout_declare_ubo(B, "ivec4 ", "in_axis_map")}
29+ ${layout_declare_ubo(B, "ivec4 ", "other_sizes")}
30+ ${layout_declare_ubo(B, "ivec4 ", "other_axis_map")}
31+ ${layout_declare_ubo(B, "ivec2 ", "broadcast_params")}
32+ ${layout_declare_ubo(B, "float ", "alpha")}
3033
3134layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3235
3336layout (constant_id = 3 ) const int packed_dim = C_DIM;
3437
3538void main() {
39+ // pos is physical (x, y, z), as global workgroup uses image extents
3640 const ivec3 pos = ivec3 (gl_GlobalInvocationID);
37- const ivec4 idx = to_tensor_idx(pos, out_sizes, packed_dim);
41+ // physical pos (x, y, z) -> logical (w, c, h, n) output
42+ const ivec4 idx = to_tensor_idx(pos, out_sizes, out_axis_map, packed_dim);
3843
3944 if (any (greaterThanEqual (idx, out_sizes))) {
4045 return ;
4146 }
4247
48+ // broadcast on logical sizes
4349 ivec4 in_idx = broadcast_indices(idx, in_sizes);
44- VEC4_T in_texel = VEC4_T(texelFetch (
50+ VEC4_T in_texel = VEC4_T(load_texel (
4551 t_in,
46- to_texture_pos(in_idx, in_sizes, packed_dim),
47- 0 ));
52+ // read axis mapped texel
53+ to_texture_pos(in_idx, in_sizes, in_axis_map, packed_dim) ));
4854
55+ // broadcast on logical sizes
4956 ivec4 other_idx = broadcast_indices(idx, other_sizes);
50- VEC4_T other_texel = VEC4_T(texelFetch (
57+ VEC4_T other_texel = VEC4_T(load_texel (
5158 t_other,
52- to_texture_pos(other_idx, other_sizes, packed_dim),
53- 0 ));
59+ // read axis mapped texel
60+ to_texture_pos(other_idx, other_sizes, other_axis_map, packed_dim) ));
5461
5562 // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
5663 if (broadcast_params.x > 0 ) {
@@ -60,5 +67,7 @@ void main() {
6067 other_texel = other_texel.xxxx;
6168 }
6269
63- imageStore(t_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
70+ imageStore(t_out,
71+ to_texture_pos(idx, out_sizes, out_axis_map, packed_dim),
72+ VEC4_T(op(in_texel, other_texel, alpha)));
6473}
0 commit comments