1919
2020layout (std430) buffer ;
2121
// Removed (-) side: the three tensors and five UBOs are declared with
// hard-coded binding indices 0 through 7.
22- ${layout_declare_tensor(0 , "w", "t_out", DTYPE, STORAGE)}
23- ${layout_declare_tensor(1 , "r", "t_in", DTYPE, STORAGE)}
24- ${layout_declare_tensor(2 , "r", "t_other", DTYPE, STORAGE)}
25- ${layout_declare_ubo(3 , "ivec4 ", "out_sizes")}
26- ${layout_declare_ubo(4 , "ivec4 ", "in_sizes")}
27- ${layout_declare_ubo(5 , "ivec4 ", "other_sizes")}
28- ${layout_declare_ubo(6 , "ivec2 ", "broadcast_params")}
29- ${layout_declare_ubo(7 , "float ", "alpha")}
// Added (+) side: every literal index is replaced by B, and an ivec4
// *_axis_map UBO is declared right after each of out_sizes, in_sizes and
// other_sizes (eleven declarations instead of eight).
// NOTE(review): B is presumably a binding counter supplied by the shader
// template preprocessor -- confirm against the code generator.
22+ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
23+ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
24+ ${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
25+ ${layout_declare_ubo(B, "ivec4 ", "out_sizes")}
26+ ${layout_declare_ubo(B, "ivec4 ", "out_axis_map")}
27+ ${layout_declare_ubo(B, "ivec4 ", "in_sizes")}
28+ ${layout_declare_ubo(B, "ivec4 ", "in_axis_map")}
29+ ${layout_declare_ubo(B, "ivec4 ", "other_sizes")}
30+ ${layout_declare_ubo(B, "ivec4 ", "other_axis_map")}
31+ ${layout_declare_ubo(B, "ivec2 ", "broadcast_params")}
32+ ${layout_declare_ubo(B, "float ", "alpha")}
3033
// Workgroup dimensions x, y, z are taken from specialization constants 0, 1, 2.
3134layout (local_size_x_id =  0 , local_size_y_id =  1 , local_size_z_id =  2 ) in ;
3235
// Specialization constant 3 selects the packed dimension; it defaults to C_DIM.
3336layout (constant_id =  3 ) const  int  packed_dim =  C_DIM;
3437
// Shader entry point: each invocation derives an output tensor index from its
// global invocation ID, loads one texel from t_in and one from t_other at the
// broadcast-adjusted indices, applies op(in_texel, other_texel, alpha) and
// stores the result into t_out.
// Lines marked '-' are the pre-change statements, lines marked '+' are their
// replacements that additionally pass the per-tensor *_axis_map.
3538void  main() {
39+   //  pos is the physical texel position (x, y, z); the global workgroup is sized from the image extents
3640  const  ivec3  pos =  ivec3 (gl_GlobalInvocationID);
37-   const  ivec4  idx =  to_tensor_idx(pos, out_sizes, packed_dim);
41+   //  physical pos (x, y, z) -> logical (w, c, h, n) output index, now via out_axis_map (NOTE(review): confirm the component order)
42+   const  ivec4  idx =  to_tensor_idx(pos, out_sizes, out_axis_map, packed_dim);
3843
  //  Skip invocations whose index is out of range of out_sizes in any component.
3944  if  (any (greaterThanEqual (idx, out_sizes))) {
4045    return ;
4146  }
4247
48+   //  broadcast on logical sizes: in_idx is derived from the output index and in_sizes
4349  ivec4  in_idx =  broadcast_indices(idx, in_sizes);
  //  The texelFetch(..., 0) call is replaced by load_texel(...), which takes no
  //  explicit trailing 0 argument.
44-   VEC4_T in_texel =  VEC4_T(texelFetch (
50+   VEC4_T in_texel =  VEC4_T(load_texel (
4551    t_in,
46-     to_texture_pos(in_idx, in_sizes, packed_dim), 
47-     0 ));
52+     //  read the axis mapped texel: position is computed with in_axis_map
53+     to_texture_pos(in_idx, in_sizes, in_axis_map, packed_dim) ));
4854
55+   //  broadcast on logical sizes: other_idx is derived from the output index and other_sizes
4956  ivec4  other_idx =  broadcast_indices(idx, other_sizes);
50-   VEC4_T other_texel =  VEC4_T(texelFetch (
57+   VEC4_T other_texel =  VEC4_T(load_texel (
5158    t_other,
52-     to_texture_pos(other_idx, other_sizes, packed_dim), 
53-     0 ));
59+     //  read the axis mapped texel: position is computed with other_axis_map
60+     to_texture_pos(other_idx, other_sizes, other_axis_map, packed_dim) ));
5461
5562  //  Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
5663  if  (broadcast_params.x >  0 ) {
  //  Diff hunk boundary: new-file lines 64-66 are not shown in this view.
@@ -60,5 +67,7 @@ void main() {
  //  Replicate the x component of other_texel across all four components.
6067    other_texel =  other_texel.xxxx;
6168  }
6269
  //  The result was previously stored directly at pos; it is now stored at the
  //  position recomputed from idx through out_axis_map.
63-   imageStore(t_out, pos, VEC4_T(op(in_texel, other_texel, alpha)));
70+   imageStore(t_out,
71+     to_texture_pos(idx, out_sizes, out_axis_map, packed_dim),
72+     VEC4_T(op(in_texel, other_texel, alpha)));
6473}
0 commit comments