From 8e3d21ed30c776986371266a521f4231e4623c29 Mon Sep 17 00:00:00 2001 From: Alexander Dean Date: Wed, 10 Sep 2025 13:43:53 -0500 Subject: [PATCH 1/8] Optimize conv2d s1p0 --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 185 +++++++----------- .../graph/ops/glsl/conv2d_pw_s1p0.yaml | 2 - .../runtime/graph/ops/impl/Convolution.cpp | 4 + 3 files changed, 80 insertions(+), 111 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index 9f84afeb1a1..217c727bcd6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -14,9 +14,6 @@ #define VEC4_T ${texel_type(DTYPE)} -#define TILE_SIZE_X uint16_t(${TILE_SIZE_X}) -#define TILE_SIZE_Y uint16_t(${TILE_SIZE_Y}) - #define op(X, A, B) ${OPERATOR} #include "indexing_utils.h" @@ -50,119 +47,89 @@ ${layout_declare_spec_const(C, "int", "ngroups", "1")} * size is only 1x1, making it easier to re-use loaded texels from t_kernel. */ void main() { - const int out_limits_scaled[2] = - {(out_limits.x + (TILE_SIZE_X - 1)) / TILE_SIZE_X, - (out_limits.y + (TILE_SIZE_Y - 1)) / TILE_SIZE_Y}; - - const uint16_t div_by_x = uint16_t(gl_GlobalInvocationID.x / out_limits_scaled[0]); - const uint16_t out_pos_xy[2] = {uint16_t(gl_GlobalInvocationID.x % out_limits_scaled[0]), div_by_x}; - const int out_pos_z = int(gl_GlobalInvocationID.y); - - // If the top left position is out of bounds, then this invocation will have - // no work to do. - if (out_pos_xy[1] >= out_limits_scaled[1] || out_pos_z >= out_limits.z) { - return; - } - // Output position for TILE_SIZE = 2 - // +--------+--------+ - // | pos[0] | pos[1] | - // +--------+--------+ - // | pos[2] | pos[3] | - // +--------+--------+ - uint16_t pos[TILE_SIZE_X * TILE_SIZE_Y * 2]; - for (uint16_t y = uint16_t(0), i = uint16_t(0); y < TILE_SIZE_Y; ++y) { - for (uint16_t x = uint16_t(0); x < TILE_SIZE_X; ++x) { - pos[i * 2] = out_pos_xy[0] * TILE_SIZE_X + x; - pos[i * 2 + 1] = out_pos_xy[1] * TILE_SIZE_Y + y; - i++; - } - } + int inputAndOutputWidth = out_limits.x; + int inputAndOutputHeight = out_limits.y; + int outputChannel = out_limits.z*4; - // Final output array where each element is a tensor value. - // Tuple of consecutive 4 elements represents a single output texel. - float sum[TILE_SIZE_X * TILE_SIZE_Y * 4]; + // Divided by 4 because the input channels are packed + int inputChannel = in_group_size/4; - // Initialize the output array with the bias value - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y * 4; i++) { - sum[i] = 0; - } + int threadHW = int(gl_GlobalInvocationID.x); + int gid1 = int(gl_GlobalInvocationID.y); - int z4 = 0; - // Since the kernel is 1x1, we only have to loop over the depth dimension. - for (int z = 0; z < in_group_size; z += 4, ++z4) { - // During prepacking, the weight tensor has been permuted so that the - // channel (IC) dim is along the x-axis, and the batch (OC) dim is along - // the z-axis. 
- float kernel_values[4 * 4]; // 4 channels, 4 elements per channel - - // Load kernel values from texels to array - [[unroll]] for (int i = 0; i < 4; ++i) { - const vec4 k_tex = texelFetch(t_kernel, ivec2(z + i, out_pos_z), 0); - kernel_values[i * 4 + 0] = k_tex.x; - kernel_values[i * 4 + 1] = k_tex.y; - kernel_values[i * 4 + 2] = k_tex.z; - kernel_values[i * 4 + 3] = k_tex.w; - } + int xIdx = threadHW % inputAndOutputWidth; + int yIdx = threadHW / inputAndOutputWidth; - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const vec4 in_tex = texelFetch(t_in, ivec3(pos[i * 2], pos[i * 2 + 1], z4), 0); - // Load the input texel into an array - float tex_values[4]; - tex_values[0] = in_tex.x; - tex_values[1] = in_tex.y; - tex_values[2] = in_tex.z; - tex_values[3] = in_tex.w; - - // For 2x2 tile size algorithm works as follows. - // To explain the calculations below, the contents of one in_tex and the - // group of 4 texels loaded from t_kernel are shown: - // - // in_tex t_kernel - // -x-> ---x---> - // +---+ +----+----+----+----+ - // ^ | w | ^ | D0 | D1 | D2 | D3 | - // | +---+ | +----+----+----+----+ - // | | z | | | C0 | C1 | C2 | C3 | - // z +---+ z +----+----+----+----+ - // | | y | | | B0 | B2 | B2 | B3 | - // | +---+ | +----+----+----+----+ - // | x | | A0 | A1 | A2 | A3 | - // +---+ +----+----+----+----+ - // - // In the t_kernel graphic, cells sharing the same letter are from - // the same batch/output channel index, and the number denotes a unique - // channel index. To calculate the output texel, the following - // calculation is performed: - // - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | D0 | | y | | D1 | | z | | D2 | | w | | D3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | C0 | | y | | C1 | | z | | C2 | | w | | C3 | - // +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+ - // | x | | B0 | | y | | B1 | | z | | B2 | | w | | B3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // | x | | A0 | | y | | A1 | | z | | A2 | | w | | A3 | - // +---+ +----+ +---+ +----+ +---+ +----+ +---+ +----+ - // - // which is what is expressed in the following calculations. This is done - // for each output position. 
- for (int j = 0; j < 4; ++j) { - sum[i * 4 + j] = tex_values[0] * kernel_values[0 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[1] * kernel_values[4 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[2] * kernel_values[8 + j] + sum[i * 4 + j]; - sum[i * 4 + j] = tex_values[3] * kernel_values[12 + j] + sum[i * 4 + j]; - } - } - } + if (threadHW < inputAndOutputWidth * inputAndOutputHeight && gid1 < outputChannel) { + + vec4 outputTexel = texelFetch(t_bias, ivec2(gid1, 0), 0); + + vec4 inputVec; + vec4 weight1OutputChannelPacked; + vec4 weight2OutputChannelPacked; + vec4 weight3OutputChannelPacked; + vec4 weight4OutputChannelPacked; + + // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions + // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute + for (int inputC = 0; inputC < inputChannel; inputC += 1) { + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); - const vec4 bias = texelFetch(t_bias, ivec2(out_pos_z, 0), 0); + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - for (int i = 0; i < TILE_SIZE_X * TILE_SIZE_Y; ++i) { - const ivec3 pos_l = ivec3(pos[i * 2], pos[i * 2 + 1], out_pos_z); - if (all(lessThan(pos_l.xy, out_limits.xy))) { - const vec4 out_sum = vec4(sum[i * 4], sum[i * 4 + 1], sum[i * 4 + 2], sum[i * 4 + 3]); - imageStore(t_out, pos_l, op(out_sum + bias, out_min, out_max)); + inputC += 1; + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + + inputC += 1; + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + 
weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + + inputC += 1; + + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); } + + imageStore(t_out, ivec3(xIdx, yIdx, gid1), op(outputTexel, out_min, out_max)); } } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml index ebfee11c405..bab3c715540 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml @@ -9,8 +9,6 @@ conv2d_pw_s1p0: OPERATOR: X NDIM: 3 DTYPE: float - TILE_SIZE_X: 1 - TILE_SIZE_Y: 4 generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index ded1defe973..ef4a4d514b0 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -364,6 +364,10 @@ utils::uvec3 conv2d_global_wg_size( if (method == Conv2dMethod::Depthwise || method == Conv2dMethod::Pointwise) { wg_size = {wg_size[0] * wg_size[1], wg_size[2], 1}; + + if (shader.kernel_name.find("s1p0") != std::string::npos) { + wg_size[0] *= 4; + } } return wg_size; From 41ec8b64829f6a74a4153c3a700c290c6e03838c Mon Sep 17 00:00:00 2001 From: Alex Dean Date: Thu, 11 Sep 2025 16:13:36 -0700 Subject: [PATCH 2/8] Stylistic changes to pw conv2d s1p0 --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 107 +++++++++--------- 1 file changed, 54 insertions(+), 53 deletions(-) diff --git 
a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
index 217c727bcd6..06443d6a028 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
@@ -56,80 +56,81 @@ void main() {
   int inputChannel = in_group_size/4;
 
   int threadHW = int(gl_GlobalInvocationID.x);
-  int gid1 = int(gl_GlobalInvocationID.y);
+  int threadOutChannel = int(gl_GlobalInvocationID.y);
 
   int xIdx = threadHW % inputAndOutputWidth;
   int yIdx = threadHW / inputAndOutputWidth;
 
-  if (threadHW < inputAndOutputWidth * inputAndOutputHeight && gid1 < outputChannel) {
-
-    vec4 outputTexel = texelFetch(t_bias, ivec2(gid1, 0), 0);
+  if (threadHW >= inputAndOutputWidth * inputAndOutputHeight || threadOutChannel >= outputChannel) {
+    return;
+  }
 
-    vec4 inputVec;
-    vec4 weight1OutputChannelPacked;
-    vec4 weight2OutputChannelPacked;
-    vec4 weight3OutputChannelPacked;
-    vec4 weight4OutputChannelPacked;
+  vec4 outputTexel = texelFetch(t_bias, ivec2(threadOutChannel, 0), 0);
+
+  vec4 inputVec;
+  vec4 weight1OutputChannelPacked;
+  vec4 weight2OutputChannelPacked;
+  vec4 weight3OutputChannelPacked;
+  vec4 weight4OutputChannelPacked;
 
-    // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions
-    // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute
-    for (int inputC = 0; inputC < inputChannel; inputC += 1) {
+  // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions
+  // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute
+  for (int inputC = 0; inputC < inputChannel; inputC += 1) {
 
-      inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0);
+    inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0);
 
-      weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0);
-      weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0);
-      weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0);
-      weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0);
+    weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0);
+    weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0);
+    weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0);
+    weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0);
 
-      outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
-      outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1]));
-      outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2]));
-      outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
+    outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0]));
+    outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], 
weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputC += 1; - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); - inputC += 1; + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputC += 1; - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], 
weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); - inputC += 1; + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputC += 1; - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, gid1), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, gid1), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, gid1), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, gid1), 0); + inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); - } + weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); + weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); + weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); + weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); - imageStore(t_out, ivec3(xIdx, yIdx, gid1), op(outputTexel, out_min, out_max)); + outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], 
weight4OutputChannelPacked[3])); } + + imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(outputTexel, out_min, out_max)); } From c1910fe25ec3ff47f3b6faaf29bff54d6bbe1ce5 Mon Sep 17 00:00:00 2001 From: Alexander Dean Date: Wed, 17 Sep 2025 16:49:55 -0500 Subject: [PATCH 3/8] Add fp16 to conv2d pw s1p0 --- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 93 ++++++++++--------- 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index 06443d6a028..ef50a1aca9f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -12,7 +12,12 @@ #define PRECISION ${PRECISION} -#define VEC4_T ${texel_type(DTYPE)} +$if DTYPE == "half": + #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require + #define VEC4_T f16vec4 +$else: + #define VEC4_T ${texel_type(DTYPE)} + #define op(X, A, B) ${OPERATOR} @@ -65,72 +70,72 @@ void main() { return; } - vec4 outputTexel = texelFetch(t_bias, ivec2(threadOutChannel, 0), 0); + VEC4_T outputTexel = VEC4_T(texelFetch(t_bias, ivec2(threadOutChannel, 0), 0)); - vec4 inputVec; - vec4 weight1OutputChannelPacked; - vec4 weight2OutputChannelPacked; - vec4 weight3OutputChannelPacked; - vec4 weight4OutputChannelPacked; + VEC4_T inputVec; + VEC4_T weight1OutputChannelPacked; + VEC4_T weight2OutputChannelPacked; + VEC4_T weight3OutputChannelPacked; + VEC4_T weight4OutputChannelPacked; // By unrolling the loop in sets of 4, this significantly reduces the number of branching instructions // and enables the compiler to rearrange instructions for more efficient memory retrieval and compute for (int inputC = 0; inputC < inputChannel; inputC += 1) { - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], 
weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); inputC += 1; - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); inputC += 1; - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); 
+ weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); inputC += 1; - inputVec = texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0); + inputVec = VEC4_T(texelFetch(t_in, ivec3(xIdx, yIdx, inputC), 0)); - weight1OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0); - weight2OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0); - weight3OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0); - weight4OutputChannelPacked = texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0); + weight1OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 0, threadOutChannel), 0)); + weight2OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 1, threadOutChannel), 0)); + weight3OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 2, threadOutChannel), 0)); + weight4OutputChannelPacked = VEC4_T(texelFetch(t_kernel, ivec2(inputC * 4 + 3, threadOutChannel), 0)); - outputTexel[0] += dot(inputVec, vec4(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); - outputTexel[1] += dot(inputVec, vec4(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); - outputTexel[2] += dot(inputVec, vec4(weight1OutputChannelPacked[2], weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); - outputTexel[3] += dot(inputVec, vec4(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); + outputTexel[0] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[0], weight2OutputChannelPacked[0], weight3OutputChannelPacked[0], weight4OutputChannelPacked[0])); + outputTexel[1] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[1], weight2OutputChannelPacked[1], weight3OutputChannelPacked[1], weight4OutputChannelPacked[1])); + outputTexel[2] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[2], 
weight2OutputChannelPacked[2], weight3OutputChannelPacked[2], weight4OutputChannelPacked[2])); + outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); } - imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(outputTexel, out_min, out_max)); + imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); } From a72e2e43405b84d3a45b8b0ddf33d3595a5cfeeb Mon Sep 17 00:00:00 2001 From: Alexander Dean Date: Thu, 18 Sep 2025 13:42:31 -0500 Subject: [PATCH 4/8] Add Fusing for Conv/Binary Ops, Clamp/Binary Ops, and Clamp/Clamp --- .../transforms/fuse_clamp_with_binary_op.py | 132 +++++ backends/transforms/fuse_clamps.py | 96 ++++ .../transforms/fuse_conv_with_binary_op.py | 108 ++++ backends/transforms/fuse_conv_with_clamp.py | 9 +- backends/transforms/targets.bzl | 48 ++ backends/vulkan/custom_ops_lib.py | 536 ++++++++++++++++++ backends/vulkan/op_registry.py | 9 + .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 10 +- .../graph/ops/glsl/conv2d_pw_s1p0.yaml | 14 + .../runtime/graph/ops/glsl/unary_op.glsl | 9 + .../runtime/graph/ops/glsl/unary_op.yaml | 13 + .../runtime/graph/ops/impl/Convolution.cpp | 92 ++- .../vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 31 +- backends/vulkan/targets.bzl | 3 + backends/vulkan/vulkan_preprocess.py | 10 +- 15 files changed, 1104 insertions(+), 16 deletions(-) create mode 100644 backends/transforms/fuse_clamp_with_binary_op.py create mode 100644 backends/transforms/fuse_clamps.py create mode 100644 backends/transforms/fuse_conv_with_binary_op.py diff --git a/backends/transforms/fuse_clamp_with_binary_op.py b/backends/transforms/fuse_clamp_with_binary_op.py new file mode 100644 index 00000000000..8e24f482695 --- /dev/null +++ b/backends/transforms/fuse_clamp_with_binary_op.py @@ -0,0 +1,132 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import sys + +import executorch.backends.vulkan.custom_ops_lib # noqa + +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +class FuseClampBinaryOpPass(ExportPass): + + FUSEABLE_OPS = [ + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.clamp.default, + ] + FUSEABLE_BINARY_OPS = [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.mul.Tensor, + exir_ops.edge.aten.div.Tensor, + ] + + def exists_before(self, graph_module, node_a, node_b): + seen_a = False + for n in graph_module.graph.nodes: + if n is node_a: + seen_a = True + if n is node_b: + return seen_a + return False + + def get_output_min_max_from_activation(self, activation_node): + if activation_node.target == exir_ops.edge.aten.relu.default: + output_min = 0.0 + output_max = sys.float_info.max + elif activation_node.target == exir_ops.edge.aten.hardtanh.default: + output_min = -1.0 + output_max = 1.0 + if len(activation_node.args) > 1: + output_min = activation_node.args[1] + output_max = activation_node.args[2] + elif activation_node.target == exir_ops.edge.aten.clamp.default: + output_min = None + output_max = None + if len(activation_node.args) >= 2: + output_min = activation_node.args[1] + if len(activation_node.args) >= 3: + output_max = activation_node.args[2] + + return output_min, output_max + + + def call(self, graph_module: torch.fx.GraphModule): + fuseAdded = True + while fuseAdded: + fuseAdded = False + for arg_idx in range(0, 2): + for binary_op_node in graph_module.graph.nodes: + if binary_op_node.op == "call_function": + if binary_op_node.target in self.FUSEABLE_BINARY_OPS: + preceding_op = binary_op_node.args[arg_idx] + + if ( + preceding_op.op == "call_function" + and preceding_op.target in self.FUSEABLE_OPS + ): + # Ensure the shapes match + if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta: + continue + if len(binary_op_node.args[1].meta["val"].shape) != len(binary_op_node.args[0].meta["val"].shape): + continue + + # Get the texture to do the binary op + texture = binary_op_node.args[(arg_idx + 1)%2] + + # Fuse only if the texture exists before the preceding op + if not self.exists_before(graph_module, texture, preceding_op): + continue + + new_args = list(preceding_op.args) + + # insert the min/max at indices 1 and 2 + output_min_max = self.get_output_min_max_from_activation( + preceding_op + ) + new_args.insert(1, output_min_max[0]) + new_args.insert(2, output_min_max[1]) + + # put the other texture at idx 3 + new_args.insert(3, texture) + new_args = new_args[0:4] + + new_args = tuple(new_args) + binary_op_node.replace_all_uses_with(preceding_op) + graph_module.graph.erase_node(binary_op_node) + + new_op = None + if binary_op_node.target == exir_ops.edge.aten.add.Tensor: + new_op = exir_ops.edge.et_vk.clamp_with_binary_add.default + if binary_op_node.target == exir_ops.edge.aten.sub.Tensor: + new_op = exir_ops.edge.et_vk.clamp_with_binary_sub.default + if binary_op_node.target == exir_ops.edge.aten.mul.Tensor: + new_op = exir_ops.edge.et_vk.clamp_with_binary_mul.default + if binary_op_node.target == exir_ops.edge.aten.div.Tensor: + new_op = exir_ops.edge.et_vk.clamp_with_binary_div.default + + assert(new_op != None) + + # Create and insert node of custom op `clamp_with_binary_op` + with graph_module.graph.inserting_before(preceding_op): + clamp_binary_op_node = graph_module.graph.create_node( + 
"call_function", + new_op, + new_args, + ) + + preceding_op.replace_all_uses_with(clamp_binary_op_node) + graph_module.graph.erase_node(preceding_op) + + fuseAdded = True + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/transforms/fuse_clamps.py b/backends/transforms/fuse_clamps.py new file mode 100644 index 00000000000..d07a7646f0c --- /dev/null +++ b/backends/transforms/fuse_clamps.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys + +import executorch.backends.vulkan.custom_ops_lib # noqa + +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + +class FuseClampsPass(ExportPass): + + FUSEABLE_CLAMPS = [ + exir_ops.edge.aten.relu.default, + exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.clamp.default, + ] + + def get_output_min_max_from_activation(self, activation_node): + if activation_node.target == exir_ops.edge.aten.relu.default: + output_min = 0.0 + output_max = sys.float_info.max + elif activation_node.target == exir_ops.edge.aten.hardtanh.default: + output_min = -1.0 + output_max = 1.0 + if len(activation_node.args) > 1: + output_min = activation_node.args[1] + output_max = activation_node.args[2] + elif activation_node.target == exir_ops.edge.aten.clamp.default: + output_min = None + output_max = None + if len(activation_node.args) >= 2: + output_min = activation_node.args[1] + if len(activation_node.args) >= 3: + output_max = activation_node.args[2] + + return output_min, output_max + + + def call(self, graph_module: torch.fx.GraphModule): + fuseAdded = True + while fuseAdded: + fuseAdded = False + for clamp_2_node in graph_module.graph.nodes: + if clamp_2_node.op == "call_function": + if clamp_2_node.target in self.FUSEABLE_CLAMPS: + preceding_op = clamp_2_node.args[0] + if ( + preceding_op.op == "call_function" + and preceding_op.target in self.FUSEABLE_CLAMPS + ): + # Ensure the shapes match + if "val" not in clamp_2_node.args[0].meta or "val" not in preceding_op.args[0].meta: + continue + if len(clamp_2_node.args[0].meta["val"].shape) != len(preceding_op.args[0].meta["val"].shape): + continue + + min_max1 = self.get_output_min_max_from_activation(preceding_op) + min_max2 = self.get_output_min_max_from_activation(clamp_2_node) + + min_max = [None, None] + + if min_max1[0] == None and min_max2[0] != None: + min_max[0] = min_max2[0] + elif min_max1[0] != None and min_max2[0] == None: + min_max[0] = min_max1[0] + else: + min_max[0] = min(min_max1[0], min_max2[0]) + + if min_max1[1] == None and min_max2[1] != None: + min_max[1] = min_max2[1] + elif min_max1[1] != None and min_max2[1] == None: + min_max[1] = min_max1[1] + else: + min_max[1] = max(min_max1[1], min_max2[1]) + + new_args = list(preceding_op.args) + + # Insert the new min/max at indices 1 and 2 + new_args.insert(1, min_max[0]) + new_args.insert(2, min_max[1]) + new_args = new_args[0:3] + preceding_op.args = tuple(new_args) + clamp_2_node.replace_all_uses_with(preceding_op) + graph_module.graph.erase_node(clamp_2_node) + fuseAdded = True + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, True) diff --git a/backends/transforms/fuse_conv_with_binary_op.py 
diff --git a/backends/transforms/fuse_conv_with_binary_op.py b/backends/transforms/fuse_conv_with_binary_op.py
new file mode 100644
index 00000000000..461d66531bc
--- /dev/null
+++ b/backends/transforms/fuse_conv_with_binary_op.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.vulkan.custom_ops_lib  # noqa
+
+import torch
+
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+class FuseConvBinaryOpPass(ExportPass):
+    """
+    Binary ops like add, sub, mul, and div can be fused into a preceding convolution when their other operand is produced earlier in the graph.
+    """
+
+    FUSEABLE_OPS = [
+        exir_ops.edge.aten.convolution.default,
+    ]
+    FUSEABLE_BINARY_OPS = [
+        exir_ops.edge.aten.add.Tensor,
+        exir_ops.edge.aten.sub.Tensor,
+        exir_ops.edge.aten.mul.Tensor,
+        exir_ops.edge.aten.div.Tensor,
+    ]
+
+    def exists_before(self, graph_module, node_a, node_b):
+        seen_a = False
+        for n in graph_module.graph.nodes:
+            if n is node_a:
+                seen_a = True
+            if n is node_b:
+                return seen_a
+        return False
+
+
+    def call(self, graph_module: torch.fx.GraphModule):
+
+        fuseAdded = True
+        while fuseAdded:
+            fuseAdded = False
+            for arg_idx in range(0, 2):
+                for binary_op_node in graph_module.graph.nodes:
+                    if binary_op_node.op == "call_function":
+                        if binary_op_node.target in self.FUSEABLE_BINARY_OPS:
+                            preceding_op = binary_op_node.args[arg_idx]
+                            if (
+                                preceding_op.op == "call_function"
+                                and preceding_op.target in self.FUSEABLE_OPS
+                            ):
+
+                                # For now only conv2d with stride 1 and padding 0 (the pointwise s1p0 fast path) is supported
+                                if not (len(preceding_op.args[3]) == 2 and preceding_op.args[3][0] == 1 and preceding_op.args[3][1] == 1 and preceding_op.args[4][0] == 0 and preceding_op.args[4][1] == 0):
+                                    continue
+
+                                # Ensure the shapes match
+                                if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta:
+                                    continue
+                                if len(binary_op_node.args[0].meta["val"].shape) != len(binary_op_node.args[1].meta["val"].shape):
+                                    continue
+
+
+                                # Get the texture to do the binary op
+                                texture = binary_op_node.args[(arg_idx + 1)%2]
+
+                                # Fuse only if the texture exists before the preceding op
+                                if not self.exists_before(graph_module, texture, preceding_op):
+                                    continue
+
+                                new_args = list(preceding_op.args)
+                                new_args.append(texture)
+                                new_args = tuple(new_args)
+                                binary_op_node.replace_all_uses_with(preceding_op)
+                                graph_module.graph.erase_node(binary_op_node)
+
+                                new_op = None
+                                if binary_op_node.target == exir_ops.edge.aten.add.Tensor:
+                                    new_op = exir_ops.edge.et_vk.conv_with_binary_add.default
+                                if binary_op_node.target == exir_ops.edge.aten.sub.Tensor:
+                                    new_op = exir_ops.edge.et_vk.conv_with_binary_sub.default
+                                if binary_op_node.target == exir_ops.edge.aten.mul.Tensor:
+                                    new_op = exir_ops.edge.et_vk.conv_with_binary_mul.default
+                                if binary_op_node.target == exir_ops.edge.aten.div.Tensor:
+                                    new_op = exir_ops.edge.et_vk.conv_with_binary_div.default
+
+                                assert(new_op != None)
+
+                                # Create and insert node of custom op `conv_with_binary_op`
+                                with graph_module.graph.inserting_before(preceding_op):
+                                    conv_binary_op_node = graph_module.graph.create_node(
+                                        "call_function",
+                                        new_op,
+                                        new_args,
+                                    )
+
+                                preceding_op.replace_all_uses_with(conv_binary_op_node)
+                                graph_module.graph.erase_node(preceding_op)
+
+                                fuseAdded = True
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, True)
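To make the rewrite concrete: the pass turns a conv followed by a binary op into one fused node when the other operand already exists before the conv. A minimal before/after sketch (illustrative only; the module and value names are invented):

import torch

class ConvAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # 1x1 conv with stride 1 and padding 0 -- the only case the pass matches
        self.conv = torch.nn.Conv2d(8, 8, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        return self.conv(x) + x  # x is produced before the conv, so exists_before holds

# Edge graph before the pass (schematic):
#   conv = aten.convolution.default(x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
#   out  = aten.add.Tensor(conv, x)
# After FuseConvBinaryOpPass, the pair collapses to a single node, with the
# other operand appended as the final argument:
#   out  = et_vk.conv_with_binary_add.default(x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, x)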
diff --git a/backends/transforms/fuse_conv_with_clamp.py b/backends/transforms/fuse_conv_with_clamp.py index 3f45296b26c..824d67a5c88 100644 --- a/backends/transforms/fuse_conv_with_clamp.py +++ b/backends/transforms/fuse_conv_with_clamp.py @@ -14,7 +14,7 @@ from executorch.exir.pass_base import ExportPass, PassResult -class FuseClampPass(ExportPass): +class FuseConvClampPass(ExportPass): """ Some activations like ReLU and hardtanh can be fused with certain operators (e.g. convolution) preceding it. """ @@ -37,6 +37,13 @@ def get_output_min_max_from_activation(self, activation_node): if len(activation_node.args) > 1: output_min = activation_node.args[1] output_max = activation_node.args[2] + elif activation_node.target == exir_ops.edge.aten.clamp.default: + output_min = None + output_max = None + if len(activation_node.args) >= 2: + output_min = activation_node.args[1] + if len(activation_node.args) >= 3: + output_max = activation_node.args[2] return output_min, output_max diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index ca09d34c2fe..f996fec0efc 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -77,6 +77,54 @@ def define_common_targets(): ], ) + runtime.python_library( + name = "fuse_conv_with_binary_op", + srcs = ["fuse_conv_with_binary_op.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/backends/vulkan:custom_ops_lib", + "//executorch/exir:pass_base", + "//executorch/exir:sym_util", + "//executorch/exir/dialects:lib", + ], + ) + + runtime.python_library( + name = "fuse_clamps", + srcs = ["fuse_clamps.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/backends/vulkan:custom_ops_lib", + "//executorch/exir:pass_base", + "//executorch/exir:sym_util", + "//executorch/exir/dialects:lib", + ], + ) + + runtime.python_library( + name = "fuse_clamp_with_binary_op", + srcs = ["fuse_clamp_with_binary_op.py"], + visibility = [ + "//executorch/backends/...", + ], + deps = [ + ":utils", + "//caffe2:torch", + "//executorch/backends/vulkan:custom_ops_lib", + "//executorch/exir:pass_base", + "//executorch/exir:sym_util", + "//executorch/exir/dialects:lib", + ], + ) + runtime.python_library( name = "view_copy_to_squeeze_unsqueeze", srcs = ["view_copy_to_squeeze_unsqueeze.py"], diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py index 56e803b9127..c4e49ba9428 100644 --- a/backends/vulkan/custom_ops_lib.py +++ b/backends/vulkan/custom_ops_lib.py @@ -109,6 +109,542 @@ def conv_with_clamp_out_impl( ) lib.impl(name, conv_with_clamp_out_impl, "CompositeExplicitAutograd") +########################## +## conv_with_binary_add ## +########################## + + +def conv_with_binary_add_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, +): + return torch.add( + torch.convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ), + other, + ) + + +name = "conv_with_binary_add" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" +) +lib.impl(name, conv_with_binary_add_impl, "CompositeExplicitAutograd") +conv_with_binary_add_op = getattr(getattr(torch.ops, namespace), name) + +############################# +## conv_with_binary_add.out ## +############################# + + +def conv_with_binary_add_out_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, + out=None, +): + out = conv_with_binary_add_impl( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + other, + ) + return out + + +name = "conv_with_binary_add.out" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.impl(name, conv_with_binary_add_out_impl, "CompositeExplicitAutograd") + +########################## +## conv_with_binary_sub ## +########################## + + +def conv_with_binary_sub_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, +): + return torch.sub( + torch.convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ), + other, + ) + + +name = "conv_with_binary_sub" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" +) +lib.impl(name, conv_with_binary_sub_impl, "CompositeExplicitAutograd") +conv_with_binary_sub_op = getattr(getattr(torch.ops, namespace), name) + +############################## +## conv_with_binary_sub.out ## +############################## + + +def conv_with_binary_sub_out_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, + out=None, +): + out = conv_with_binary_sub_impl( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + other, + ) + return out + + +name = "conv_with_binary_sub.out" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.impl(name, conv_with_binary_sub_out_impl, "CompositeExplicitAutograd") + +########################## +## conv_with_binary_mul ## +########################## + + +def conv_with_binary_mul_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, +): + return torch.mul( + torch.convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ), + other, + ) + + +name = "conv_with_binary_mul" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? 
bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" +) +lib.impl(name, conv_with_binary_mul_impl, "CompositeExplicitAutograd") +conv_with_binary_mul_op = getattr(getattr(torch.ops, namespace), name) + +############################## +## conv_with_binary_mul.out ## +############################## + + +def conv_with_binary_mul_out_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, + out=None, +): + out = conv_with_binary_mul_impl( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + other, + ) + return out + + +name = "conv_with_binary_mul.out" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.impl(name, conv_with_binary_mul_out_impl, "CompositeExplicitAutograd") + +########################## +## conv_with_binary_div ## +########################## + + +def conv_with_binary_div_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, +): + return torch.div( + torch.convolution( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + ), + other, + ) + + +name = "conv_with_binary_div" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other) -> Tensor" +) +lib.impl(name, conv_with_binary_div_impl, "CompositeExplicitAutograd") +conv_with_binary_div_op = getattr(getattr(torch.ops, namespace), name) + +############################## +## conv_with_binary_div.out ## +############################## + + +def conv_with_binary_div_out_impl( + input, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + transposed=False, + output_padding=0, + groups=1, + other=None, + out=None, +): + out = conv_with_binary_div_impl( + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + other, + ) + return out + + +name = "conv_with_binary_div.out" +lib.define( + f"{name}(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, Tensor other, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.impl(name, conv_with_binary_div_out_impl, "CompositeExplicitAutograd") + +########################### +## clamp_with_binary_add ## +########################### + + +def clamp_with_binary_add_impl( + input, + output_min=-float("inf"), + output_max=float("inf"), + other=None, +): + return torch.add( + torch.clamp( + input, + output_min, + output_max, + ), + other, + ) + + +name = "clamp_with_binary_add" +lib.define( + f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? 
other) -> Tensor"
+)
+lib.impl(name, clamp_with_binary_add_impl, "CompositeExplicitAutograd")
+clamp_with_binary_add_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## clamp_with_binary_add.out ##
+###############################
+
+
+def clamp_with_binary_add_out_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+    out=None,
+):
+    out = clamp_with_binary_add_impl(
+        input,
+        output_min,
+        output_max,
+        other,
+    )
+    return out
+
+
+name = "clamp_with_binary_add.out"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, clamp_with_binary_add_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## clamp_with_binary_sub ##
+###########################
+
+
+def clamp_with_binary_sub_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+):
+    return torch.sub(
+        torch.clamp(
+            input,
+            output_min,
+            output_max,
+        ),
+        other,
+    )
+
+
+name = "clamp_with_binary_sub"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor"
+)
+lib.impl(name, clamp_with_binary_sub_impl, "CompositeExplicitAutograd")
+clamp_with_binary_sub_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## clamp_with_binary_sub.out ##
+###############################
+
+
+def clamp_with_binary_sub_out_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+    out=None,
+):
+    out = clamp_with_binary_sub_impl(
+        input,
+        output_min,
+        output_max,
+        other,
+    )
+    return out
+
+
+name = "clamp_with_binary_sub.out"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, clamp_with_binary_sub_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## clamp_with_binary_mul ##
+###########################
+
+
+def clamp_with_binary_mul_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+):
+    return torch.mul(
+        torch.clamp(
+            input,
+            output_min,
+            output_max,
+        ),
+        other,
+    )
+
+
+name = "clamp_with_binary_mul"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other) -> Tensor"
+)
+lib.impl(name, clamp_with_binary_mul_impl, "CompositeExplicitAutograd")
+clamp_with_binary_mul_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## clamp_with_binary_mul.out ##
+###############################
+
+
+def clamp_with_binary_mul_out_impl(
+    input,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    other=None,
+    out=None,
+):
+    out = clamp_with_binary_mul_impl(
+        input,
+        output_min,
+        output_max,
+        other,
+    )
+    return out
+
+
+name = "clamp_with_binary_mul.out"
+lib.define(
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, clamp_with_binary_mul_out_impl, "CompositeExplicitAutograd")
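+
+# Sanity sketch for the fused clamp/binary ops defined above: each should
+# match the eager composition it replaces. Illustrative only, assuming the
+# `et_vk` namespace these ops are registered under:
+#
+#   x = torch.randn(1, 4, 8, 8)
+#   other = torch.randn(1, 4, 8, 8)
+#   ref = torch.add(torch.clamp(x, 0.0, 6.0), other)
+#   fused = torch.ops.et_vk.clamp_with_binary_add(x, 0.0, 6.0, other)
+#   torch.testing.assert_close(fused, ref)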
other) -> Tensor" +) +lib.impl(name, clamp_with_binary_div_impl, "CompositeExplicitAutograd") +clamp_with_binary_div_op = getattr(getattr(torch.ops, namespace), name) + +############################### +## clamp_with_binary_div.out ## +############################### + + +def clamp_with_binary_div_out_impl( + input, + output_min=-float("inf"), + output_max=float("inf"), + other=None, + out=None, +): + out = clamp_with_binary_div_impl( + input, + output_min, + output_max, + other, + ) + return out + + +name = "clamp_with_binary_div.out" +lib.define( + f"{name}(Tensor input, Tensor weight, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)" +) +lib.impl(name, clamp_with_binary_div_out_impl, "CompositeExplicitAutograd") + ################# ## grid_priors ## ################# diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 4c686e0cfc5..62c9f6c18b9 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -238,6 +238,10 @@ def register_binary_op(): exir_ops.edge.aten.tanh.default, exir_ops.edge.aten.round.default, exir_ops.edge.aten.leaky_relu.default, + exir_ops.edge.et_vk.clamp_with_binary_add.default, + exir_ops.edge.et_vk.clamp_with_binary_sub.default + exir_ops.edge.et_vk.clamp_with_binary_mul.default, + exir_ops.edge.et_vk.clamp_with_binary_div.default ] ) def register_unary_op(): @@ -471,6 +475,10 @@ def register_2d_pool_op(): [ exir_ops.edge.aten.convolution.default, exir_ops.edge.et_vk.conv_with_clamp.default, + exir_ops.edge.et_vk.conv_with_binary_add.default, + exir_ops.edge.et_vk.conv_with_binary_sub.default, + exir_ops.edge.et_vk.conv_with_binary_mul.default, + exir_ops.edge.et_vk.conv_with_binary_div.default, ] ) def register_convolution_op(): @@ -487,6 +495,7 @@ def register_convolution_op(): utils.NO_STORAGE, # groups (non tensor) utils.NO_STORAGE, # output_min (non tensor) utils.NO_STORAGE, # output_max (non tensor) + utils.NO_STORAGE, # other (prepacked) ], supports_resize=True, supports_prepacking=True, diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl index ef50a1aca9f..79577a0e2a2 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl @@ -30,6 +30,10 @@ ${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")} ${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")} ${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")} +$if FUSED_OP != "A": + #define fused_op(A, B) ${FUSED_OP} + ${layout_declare_tensor(4, "r", "t_other", DTYPE, "texture3d")} + layout(push_constant) uniform restrict Block { ivec4 out_limits; ivec2 stride; @@ -137,5 +141,9 @@ void main() { outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3])); } - imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); +$if FUSED_OP != "A": + VEC4_T otherTexel = VEC4_T(texelFetch(t_other, ivec3(xIdx, yIdx, threadOutChannel), 0)); + fused_op(outputTexel, otherTexel); + +imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml index bab3c715540..0cc872e0392 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml 
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml @@ -9,6 +9,8 @@ conv2d_pw_s1p0: OPERATOR: X NDIM: 3 DTYPE: float + FUSED_OP: A + IS_FUSED: 0 generate_variant_forall: DTYPE: - VALUE: half @@ -17,3 +19,15 @@ conv2d_pw_s1p0: - NAME: conv2d_pw_s1p0 - NAME: conv2d_pw_s1p0_clamp OPERATOR: clamp(X, A, B) + - NAME: conv2d_pw_s1p0_add + FUSED_OP: A += B + IS_FUSED: 1 + - NAME: conv2d_pw_s1p0_sub + FUSED_OP: A -= B + IS_FUSED: 1 + - NAME: conv2d_pw_s1p0_mul + FUSED_OP: A *= B + IS_FUSED: 1 + - NAME: conv2d_pw_s1p0_div + FUSED_OP: A /= B + IS_FUSED: 1 diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl index bb7ce482a7a..d5df5bf00cf 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl @@ -17,6 +17,9 @@ ${define_active_storage_type(STORAGE)} +$if BINARY_OP != "A": + #define binary_op(A, B) ${BINARY_OP} + #include "indexing_utils.h" ${define_required_extensions(DTYPE)} @@ -25,6 +28,8 @@ layout(std430) buffer; ${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)} ${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)} +$if BINARY_OP != "A": + ${layout_declare_tensor(2, "r", "t_other", DTYPE, STORAGE)} layout(push_constant) uniform restrict Block { $if STORAGE == "buffer": @@ -61,7 +66,11 @@ void main() { } VEC4_T in_texel = texelFetch(t_in, pos, 0); + +$if BINARY_OP == "A": imageStore(t_out, pos, VEC4_T(op(in_texel, minimum, maximum))); +$else: + imageStore(t_out, pos, VEC4_T(binary_op(op(in_texel, minimum, maximum), texelFetch(t_other, pos, 0)))); } #endif diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 47f538aee6c..e22affce11e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -3,6 +3,7 @@ unary_op: OPERATOR: clamp(X, A, B) DTYPE: float STORAGE: texture3d + BINARY_OP: A generate_variant_forall: DTYPE: - VALUE: half @@ -15,6 +16,18 @@ unary_op: OPERATOR: abs(X) - NAME: clamp OPERATOR: clamp(X, A, B) + - NAME: clamp_add + OPERATOR: clamp(X, A, B) + BINARY_OP: A + B + - NAME: clamp_sub + OPERATOR: clamp(X, A, B) + BINARY_OP: A - B + - NAME: clamp_mul + OPERATOR: clamp(X, A, B) + BINARY_OP: A * B + - NAME: clamp_div + OPERATOR: clamp(X, A, B) + BINARY_OP: A / B - NAME: clamp_int32 OPERATOR: clamp(X, A, B) DTYPE: int32 diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 479bb44ae6f..0fbcdcedfc2 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -131,7 +131,8 @@ vkapi::ShaderInfo get_conv2d_shader( const ValueRef weight, const bool clamp_out = false, const bool stride_equals_dilation = false, - const bool stride_1_padding_0 = false) { + const bool stride_1_padding_0 = false, + const std::string& op_name = "") { std::string kernel_name; kernel_name.reserve(kShaderNameReserve); switch (method) { @@ -156,6 +157,19 @@ vkapi::ShaderInfo get_conv2d_shader( } else { kernel_name = stride_1_padding_0 ? 
"conv2d_pw_s1p0" : "conv2d_pw"; } + + // For now, binary op fusing is only supported for Conv2D PW s1p0 + if (stride_1_padding_0) { + if (op_name == "conv_add") { + kernel_name += "_add"; + } else if (op_name == "conv_sub") { + kernel_name += "_sub"; + } else if (op_name == "conv_mul") { + kernel_name += "_mul"; + } else if (op_name == "conv_div") { + kernel_name += "_div"; + } + } break; case Conv2dMethod::SlidingWindow: kernel_name = "conv2d"; @@ -445,8 +459,10 @@ void add_conv2d_node( const ValueRef groups, const ValueRef out_min, const ValueRef out_max, + const ValueRef binary_op_other, const ValueRef out, - const bool clamp_out) { + const bool clamp_out, + const std::string& op_name) { const bool transposed_val = graph.get_bool(transposed); float out_min_val = 0.0f; @@ -509,7 +525,8 @@ void add_conv2d_node( weight_data, clamp_out, stride_equals_dilation, - stride_1_padding_0); + stride_1_padding_0, + op_name); utils::uvec3 wg_size = create_conv2d_global_wg_size( graph, method, out, weight_data, stride_equals_dilation); @@ -591,13 +608,27 @@ void add_conv2d_node( }; } + ValueRef arg_binary_op_other = binary_op_other; + if (binary_op_other != kDummyValueRef) { + arg_binary_op_other = + prepack_standard_like(graph, binary_op_other, out, true); + + check_conv_args(graph, arg_binary_op_other, out); + } + + std::vector args = { + {out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}; + std::vector args_with_binary_op = { + {out, vkapi::kWrite}, + {{in, arg_weight, arg_bias, arg_binary_op_other}, vkapi::kRead}}; + graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, conv2d_global_wg_size, conv2d_local_wg_size, // Inputs and Outputs - {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, + binary_op_other == kDummyValueRef ? 
args : args_with_binary_op, // Shader params buffers param_buffers, // Push Constants @@ -710,7 +741,10 @@ void add_conv1d_node( resize_conv1d_node)); } -void conv(ComputeGraph& graph, const std::vector& args) { +void conv( + ComputeGraph& graph, + const std::vector& args, + const std::string& op_name) { int64_t in_ndim = graph.dim_of(args[0]); if (in_ndim == 4) { if (args.size() == 10) { @@ -728,8 +762,29 @@ void conv(ComputeGraph& graph, const std::vector& args) { args[8], /*out_min = */ kDummyValueRef, /*out_max = */ kDummyValueRef, + /*binary_op_other = */ kDummyValueRef, args[9], - false); + false, + op_name); + } else if (args.size() == 11) { + // conv2d with binary op + return add_conv2d_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + args[7], + args[8], + /*out_min = */ kDummyValueRef, + /*out_max = */ kDummyValueRef, + args[9], + args[10], + false, + op_name); } else { // conv2d with clamp return add_conv2d_node( @@ -745,8 +800,10 @@ void conv(ComputeGraph& graph, const std::vector& args) { args[8], args[9], args[10], + /*binary_op_other = */ kDummyValueRef, args[11], - true); + true, + op_name); } } else { if (args.size() == 10) { @@ -783,10 +840,25 @@ void conv(ComputeGraph& graph, const std::vector& args) { } } +#define DEFINE_CONV_BINARY_OP_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + return conv(graph, args, #op_name); \ + } + +DEFINE_CONV_BINARY_OP_FN(conv_no_op); +DEFINE_CONV_BINARY_OP_FN(conv_add); +DEFINE_CONV_BINARY_OP_FN(conv_sub); +DEFINE_CONV_BINARY_OP_FN(conv_mul); +DEFINE_CONV_BINARY_OP_FN(conv_div); + REGISTER_OPERATORS { - VK_REGISTER_OP(aten.convolution.default, conv); - VK_REGISTER_OP(conv_with_clamp.default, conv); - VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv); + VK_REGISTER_OP(aten.convolution.default, conv_no_op); + VK_REGISTER_OP(conv_with_clamp.default, conv_no_op); + VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv_no_op); + VK_REGISTER_OP(et_vk.conv_with_binary_add.default, conv_add); + VK_REGISTER_OP(et_vk.conv_with_binary_sub.default, conv_sub); + VK_REGISTER_OP(et_vk.conv_with_binary_mul.default, conv_mul); + VK_REGISTER_OP(et_vk.conv_with_binary_div.default, conv_div); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 9830a8e8784..f64133fec01 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -39,11 +39,16 @@ void add_unary_op_node( const float min, const float max, const ValueRef out, - const std::string& op_name) { + const std::string& op_name, + const ValueRef other = kDummyValueRef) { std::string kernel_name(op_name); add_dtype_suffix(kernel_name, graph.dtype_of(out)); add_storage_type_suffix(kernel_name, graph.storage_type_of(out)); + std::vector args = {{out, vkapi::kWrite}, {in, vkapi::kRead}}; + std::vector args_with_binary_op = { + {out, vkapi::kWrite}, {{in, other}, vkapi::kRead}}; + const utils::vec2 min_max = {min, max}; graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, @@ -51,7 +56,7 @@ void add_unary_op_node( default_pick_global_wg_size, default_pick_local_wg_size, // Inputs and Outputs - {{out, vkapi::kWrite}, {in, vkapi::kRead}}, + other == kDummyValueRef ? 
args : args_with_binary_op, // Shader params buffers {}, // Push Constants @@ -94,6 +99,18 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) { kClampShaderName); \ } +#define DEFINE_CLAMP_BINARY_OP_FN(op_name) \ + void op_name(ComputeGraph& graph, const std::vector& args) { \ + return add_unary_op_node( \ + graph, \ + args[0], \ + get_val_or_inf(graph, args[1], /*max = */ false), \ + get_val_or_inf(graph, args[2], /*max = */ true), \ + args[args.size() - 1], \ + #op_name, \ + args[3]); \ + } + #define DEFINE_RELU_FN(op_name) \ void op_name(ComputeGraph& graph, const std::vector& args) { \ return add_unary_op_node( \ @@ -159,6 +176,11 @@ DEFINE_ACTIVATION_FN(hardsigmoid); DEFINE_LEAKY_RELU_FN(leaky_relu); DEFINE_ACTIVATION_FN(round); +DEFINE_CLAMP_BINARY_OP_FN(clamp_add); +DEFINE_CLAMP_BINARY_OP_FN(clamp_sub); +DEFINE_CLAMP_BINARY_OP_FN(clamp_mul); +DEFINE_CLAMP_BINARY_OP_FN(clamp_div); + REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); VK_REGISTER_OP(aten.clamp.default, clamp); @@ -179,6 +201,11 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid); VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu); VK_REGISTER_OP(aten.round.default, round); + + VK_REGISTER_OP(et_vk.clamp_with_binary_add.default, clamp_add); + VK_REGISTER_OP(et_vk.clamp_with_binary_sub.default, clamp_sub); + VK_REGISTER_OP(et_vk.clamp_with_binary_mul.default, clamp_mul); + VK_REGISTER_OP(et_vk.clamp_with_binary_div.default, clamp_div); } } // namespace vkcompute diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl index a9ba62b6f9f..8a95134578d 100644 --- a/backends/vulkan/targets.bzl +++ b/backends/vulkan/targets.bzl @@ -381,6 +381,9 @@ def define_common_targets(is_fbcode = False): deps = [ "//executorch/backends/transforms:addmm_mm_to_linear", "//executorch/backends/transforms:fuse_batch_norm_with_conv", + "//executorch/backends/transforms:fuse_clamp_with_binary_op", + "//executorch/backends/transforms:fuse_clamps", + "//executorch/backends/transforms:fuse_conv_with_binary_op", "//executorch/backends/transforms:fuse_conv_with_clamp", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:remove_clone_ops", diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 95da66494e0..494351d0fb0 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -13,7 +13,10 @@ import executorch.backends.vulkan.utils as utils from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform -from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass +from executorch.backends.transforms.fuse_conv_with_clamp import FuseConvClampPass +from executorch.backends.transforms.fuse_conv_with_binary_op import FuseConvBinaryOpPass +from executorch.backends.transforms.fuse_clamp_with_binary_op import FuseClampBinaryOpPass +from executorch.backends.transforms.fuse_clamps import FuseClampsPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import ( ViewCopyToSqueezeUnsqueezePass, @@ -167,7 +170,10 @@ def preprocess( # noqa: C901 FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(), FuseBatchNormPass(program), - FuseClampPass(), + FuseClampsPass(), + FuseConvClampPass(), + FuseConvBinaryOpPass(), + FuseClampBinaryOpPass(), ], ) From 7e8d9b12f25ac7339a8ab3b0130c190ecfedd4d1 Mon Sep 17 00:00:00 2001 From: Alexander Dean Date: Thu, 18 
Sep 2025 14:39:59 -0500 Subject: [PATCH 5/8] Lint fixes --- .../transforms/fuse_clamp_with_binary_op.py | 4 +- backends/transforms/fuse_clamps.py | 10 +- .../transforms/fuse_conv_with_binary_op.py | 108 +++++++++--------- backends/vulkan/op_registry.py | 2 +- 4 files changed, 58 insertions(+), 66 deletions(-) diff --git a/backends/transforms/fuse_clamp_with_binary_op.py b/backends/transforms/fuse_clamp_with_binary_op.py index 8e24f482695..62f0a56fbe4 100644 --- a/backends/transforms/fuse_clamp_with_binary_op.py +++ b/backends/transforms/fuse_clamp_with_binary_op.py @@ -78,7 +78,7 @@ def call(self, graph_module: torch.fx.GraphModule): continue # Get the texture to do the binary op - texture = binary_op_node.args[(arg_idx + 1)%2] + texture = binary_op_node.args[(arg_idx + 1) % 2] # Fuse only if the texture exists before the preceding op if not self.exists_before(graph_module, texture, preceding_op): @@ -111,8 +111,6 @@ def call(self, graph_module: torch.fx.GraphModule): if binary_op_node.target == exir_ops.edge.aten.div.Tensor: new_op = exir_ops.edge.et_vk.clamp_with_binary_div.default - assert(new_op != None) - # Create and insert node of custom op `clamp_with_binary_op` with graph_module.graph.inserting_before(preceding_op): clamp_binary_op_node = graph_module.graph.create_node( diff --git a/backends/transforms/fuse_clamps.py b/backends/transforms/fuse_clamps.py index d07a7646f0c..0191c994747 100644 --- a/backends/transforms/fuse_clamps.py +++ b/backends/transforms/fuse_clamps.py @@ -54,7 +54,7 @@ def call(self, graph_module: torch.fx.GraphModule): preceding_op.op == "call_function" and preceding_op.target in self.FUSEABLE_CLAMPS ): - # Ensure the shapes match + # Ensure the shapes match if "val" not in clamp_2_node.args[0].meta or "val" not in preceding_op.args[0].meta: continue if len(clamp_2_node.args[0].meta["val"].shape) != len(preceding_op.args[0].meta["val"].shape): @@ -65,16 +65,16 @@ def call(self, graph_module: torch.fx.GraphModule): min_max = [None, None] - if min_max1[0] == None and min_max2[0] != None: + if min_max1[0] is None and min_max2[0] is not None: min_max[0] = min_max2[0] - elif min_max1[0] != None and min_max2[0] == None: + elif min_max1[0] is not None and min_max2[0] is None: min_max[0] = min_max1[0] else: min_max[0] = min(min_max1[0], min_max2[0]) - if min_max1[1] == None and min_max2[1] != None: + if min_max1[1] is None and min_max2[1] is not None: min_max[1] = min_max2[1] - elif min_max1[1] != None and min_max2[1] == None: + elif min_max1[1] is not None and min_max2[1] is None: min_max[1] = min_max1[1] else: min_max[1] = max(min_max1[1], min_max2[1]) diff --git a/backends/transforms/fuse_conv_with_binary_op.py b/backends/transforms/fuse_conv_with_binary_op.py index 461d66531bc..6cf5d6054a2 100644 --- a/backends/transforms/fuse_conv_with_binary_op.py +++ b/backends/transforms/fuse_conv_with_binary_op.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import sys - import executorch.backends.vulkan.custom_ops_lib # noqa import torch @@ -45,62 +43,58 @@ def call(self, graph_module: torch.fx.GraphModule): fuseAdded = False for arg_idx in range(0, 2): for binary_op_node in graph_module.graph.nodes: - if binary_op_node.op == "call_function": - if binary_op_node.target in self.FUSEABLE_BINARY_OPS: - preceding_op = binary_op_node.args[arg_idx] - if ( - preceding_op.op == "call_function" - and preceding_op.target in self.FUSEABLE_OPS - ): - - # For now only pw conv2d s1p0 is supported - if not (len(preceding_op.args[3]) == 2 and preceding_op.args[3][0] == 1 and preceding_op.args[3][1] == 1 and preceding_op.args[4][0] == 0 and preceding_op.args[4][1] == 0): - continue - - # Ensure the shapes match - if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta: - continue - if len(binary_op_node.args[0].meta["val"].shape) != len(binary_op_node.args[1].meta["val"].shape): - continue - - - # Get the texture to do the binary op - texture = binary_op_node.args[(arg_idx + 1)%2] - - # Fuse only if the texture exists before the preceding op - if not self.exists_before(graph_module, texture, preceding_op): - continue - - new_args = list(preceding_op.args) - new_args.append(texture) - new_args = tuple(new_args) - binary_op_node.replace_all_uses_with(preceding_op) - graph_module.graph.erase_node(binary_op_node) - - new_op = None - if binary_op_node.target == exir_ops.edge.aten.add.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_add.default - if binary_op_node.target == exir_ops.edge.aten.sub.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_sub.default - if binary_op_node.target == exir_ops.edge.aten.mul.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_mul.default - if binary_op_node.target == exir_ops.edge.aten.div.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_div.default - - assert(new_op != None) - - # Create and insert node of custom op `conv_with_binary_op` - with graph_module.graph.inserting_before(preceding_op): - conv_binary_op_node = graph_module.graph.create_node( - "call_function", - new_op, - new_args, - ) - - preceding_op.replace_all_uses_with(conv_binary_op_node) - graph_module.graph.erase_node(preceding_op) + if binary_op_node.op == "call_function" and binary_op_node.target in self.FUSEABLE_BINARY_OPS: + preceding_op = binary_op_node.args[arg_idx] + if ( + preceding_op.op == "call_function" + and preceding_op.target in self.FUSEABLE_OPS + ): + + # For now only pw conv2d s1p0 is supported + if not (len(preceding_op.args[3]) == 2 and preceding_op.args[3][0] == 1 and preceding_op.args[3][1] == 1 and preceding_op.args[4][0] == 0 and preceding_op.args[4][1] == 0): + continue + + # Ensure the shapes match + if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta: + continue + if len(binary_op_node.args[0].meta["val"].shape) != len(binary_op_node.args[1].meta["val"].shape): + continue + + # Get the texture to do the binary op + texture = binary_op_node.args[(arg_idx + 1) % 2] + + # Fuse only if the texture exists before the preceding op + if not self.exists_before(graph_module, texture, preceding_op): + continue + + new_args = list(preceding_op.args) + new_args.append(texture) + new_args = tuple(new_args) + binary_op_node.replace_all_uses_with(preceding_op) + graph_module.graph.erase_node(binary_op_node) + + new_op = None + if binary_op_node.target == exir_ops.edge.aten.add.Tensor: + new_op = exir_ops.edge.et_vk.conv_with_binary_add.default + if 
binary_op_node.target == exir_ops.edge.aten.sub.Tensor: + new_op = exir_ops.edge.et_vk.conv_with_binary_sub.default + if binary_op_node.target == exir_ops.edge.aten.mul.Tensor: + new_op = exir_ops.edge.et_vk.conv_with_binary_mul.default + if binary_op_node.target == exir_ops.edge.aten.div.Tensor: + new_op = exir_ops.edge.et_vk.conv_with_binary_div.default + + # Create and insert node of custom op `conv_with_binary_op` + with graph_module.graph.inserting_before(preceding_op): + conv_binary_op_node = graph_module.graph.create_node( + "call_function", + new_op, + new_args, + ) + + preceding_op.replace_all_uses_with(conv_binary_op_node) + graph_module.graph.erase_node(preceding_op) - fuseAdded = True + fuseAdded = True graph_module.recompile() graph_module = super().call(graph_module).graph_module diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 62c9f6c18b9..ea1f82173d6 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -239,7 +239,7 @@ def register_binary_op(): exir_ops.edge.aten.round.default, exir_ops.edge.aten.leaky_relu.default, exir_ops.edge.et_vk.clamp_with_binary_add.default, - exir_ops.edge.et_vk.clamp_with_binary_sub.default + exir_ops.edge.et_vk.clamp_with_binary_sub.default, exir_ops.edge.et_vk.clamp_with_binary_mul.default, exir_ops.edge.et_vk.clamp_with_binary_div.default ] From ffb0cdf5704d4da2e0cf70e410530ae5a771fc38 Mon Sep 17 00:00:00 2001 From: Alexander Dean Date: Mon, 22 Sep 2025 15:00:36 -0500 Subject: [PATCH 6/8] Remove conv + binary ops fusing and change binary op + clamp fusing --- .../transforms/fuse_clamp_with_binary_op.py | 129 +++++----- backends/transforms/fuse_clamps.py | 21 +- .../transforms/fuse_conv_with_binary_op.py | 102 -------- backends/transforms/fuse_conv_with_clamp.py | 1 + backends/transforms/targets.bzl | 16 -- backends/vulkan/custom_ops_lib.py | 229 +++++++++++++++++- backends/vulkan/op_registry.py | 11 +- .../runtime/graph/ops/glsl/binary_op.glsl | 59 ++++- .../graph/ops/glsl/conv2d_pw_s1p0.glsl | 10 +- .../graph/ops/glsl/conv2d_pw_s1p0.yaml | 12 - .../runtime/graph/ops/glsl/unary_op.glsl | 8 - .../runtime/graph/ops/glsl/unary_op.yaml | 13 - .../runtime/graph/ops/impl/BinaryOp.cpp | 102 +++++++- .../runtime/graph/ops/impl/Convolution.cpp | 92 +------ .../vulkan/runtime/graph/ops/impl/UnaryOp.cpp | 31 +-- backends/vulkan/targets.bzl | 1 - backends/vulkan/vulkan_preprocess.py | 8 +- 17 files changed, 469 insertions(+), 376 deletions(-) delete mode 100644 backends/transforms/fuse_conv_with_binary_op.py diff --git a/backends/transforms/fuse_clamp_with_binary_op.py b/backends/transforms/fuse_clamp_with_binary_op.py index 62f0a56fbe4..4155b2b7458 100644 --- a/backends/transforms/fuse_clamp_with_binary_op.py +++ b/backends/transforms/fuse_clamp_with_binary_op.py @@ -13,9 +13,10 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult + class FuseClampBinaryOpPass(ExportPass): - FUSEABLE_OPS = [ + FUSEABLE_CLAMP_OPS = [ exir_ops.edge.aten.relu.default, exir_ops.edge.aten.hardtanh.default, exir_ops.edge.aten.clamp.default, @@ -55,76 +56,68 @@ def get_output_min_max_from_activation(self, activation_node): output_max = activation_node.args[2] return output_min, output_max - + + def fuse_binary_op_with_clamp(self, graph_module: torch.fx.GraphModule): + fuseAdded = False + for clamp_node in graph_module.graph.nodes: + if clamp_node.op == "call_function": + if clamp_node.target in self.FUSEABLE_CLAMP_OPS: + preceding_op = 
clamp_node.args[0] + + if ( + preceding_op.op == "call_function" + and preceding_op.target in self.FUSEABLE_BINARY_OPS + ): + # Delete activation + output_min_max = self.get_output_min_max_from_activation( + clamp_node + ) + new_args = list(preceding_op.args) + new_args.append(output_min_max[0]) + new_args.append(output_min_max[1]) + new_args = tuple(new_args) + clamp_node.replace_all_uses_with(preceding_op) + graph_module.graph.erase_node(clamp_node) + + new_op = None + match preceding_op.target: + case exir_ops.edge.aten.add.Tensor: + new_op = ( + exir_ops.edge.et_vk.binary_add_with_clamp.default + ) + case exir_ops.edge.aten.sub.Tensor: + new_op = ( + exir_ops.edge.et_vk.binary_sub_with_clamp.default + ) + case exir_ops.edge.aten.mul.Tensor: + new_op = ( + exir_ops.edge.et_vk.binary_mul_with_clamp.default + ) + case exir_ops.edge.aten.div.Tensor: + new_op = ( + exir_ops.edge.et_vk.binary_div_with_clamp.default + ) + + # Create and insert node of custom op `binary__with_clamp` + with graph_module.graph.inserting_before(preceding_op): + binary_op_clamp_node = graph_module.graph.create_node( + "call_function", + new_op, + new_args, + ) + + preceding_op.replace_all_uses_with(binary_op_clamp_node) + graph_module.graph.erase_node(preceding_op) + + fuseAdded = True + + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + return [fuseAdded, graph_module] def call(self, graph_module: torch.fx.GraphModule): fuseAdded = True while fuseAdded: - fuseAdded = False - for arg_idx in range(0, 2): - for binary_op_node in graph_module.graph.nodes: - if binary_op_node.op == "call_function": - if binary_op_node.target in self.FUSEABLE_BINARY_OPS: - preceding_op = binary_op_node.args[arg_idx] - - if ( - preceding_op.op == "call_function" - and preceding_op.target in self.FUSEABLE_OPS - ): - # Ensure the shapes match - if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta: - continue - if len(binary_op_node.args[1].meta["val"].shape) != len(binary_op_node.args[0].meta["val"].shape): - continue - - # Get the texture to do the binary op - texture = binary_op_node.args[(arg_idx + 1) % 2] - - # Fuse only if the texture exists before the preceding op - if not self.exists_before(graph_module, texture, preceding_op): - continue - - new_args = list(preceding_op.args) - - # insert the min/max at indices 1 and 2 - output_min_max = self.get_output_min_max_from_activation( - preceding_op - ) - new_args.insert(1, output_min_max[0]) - new_args.insert(2, output_min_max[1]) - - # put the other texture at idx 3 - new_args.insert(3, texture) - new_args = new_args[0:4] - - new_args = tuple(new_args) - binary_op_node.replace_all_uses_with(preceding_op) - graph_module.graph.erase_node(binary_op_node) - - new_op = None - if binary_op_node.target == exir_ops.edge.aten.add.Tensor: - new_op = exir_ops.edge.et_vk.clamp_with_binary_add.default - if binary_op_node.target == exir_ops.edge.aten.sub.Tensor: - new_op = exir_ops.edge.et_vk.clamp_with_binary_sub.default - if binary_op_node.target == exir_ops.edge.aten.mul.Tensor: - new_op = exir_ops.edge.et_vk.clamp_with_binary_mul.default - if binary_op_node.target == exir_ops.edge.aten.div.Tensor: - new_op = exir_ops.edge.et_vk.clamp_with_binary_div.default - - # Create and insert node of custom op `clamp_with_binary_op` - with graph_module.graph.inserting_before(preceding_op): - clamp_binary_op_node = graph_module.graph.create_node( - "call_function", - new_op, - new_args, - ) - - 
preceding_op.replace_all_uses_with(clamp_binary_op_node) - graph_module.graph.erase_node(preceding_op) - - fuseAdded = True - - graph_module.recompile() - graph_module = super().call(graph_module).graph_module + fuseAdded, graph_module = self.fuse_binary_op_with_clamp(graph_module) return PassResult(graph_module, True) diff --git a/backends/transforms/fuse_clamps.py b/backends/transforms/fuse_clamps.py index 0191c994747..6e5be508d54 100644 --- a/backends/transforms/fuse_clamps.py +++ b/backends/transforms/fuse_clamps.py @@ -13,6 +13,7 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult + class FuseClampsPass(ExportPass): FUSEABLE_CLAMPS = [ @@ -40,7 +41,6 @@ def get_output_min_max_from_activation(self, activation_node): output_max = activation_node.args[2] return output_min, output_max - def call(self, graph_module: torch.fx.GraphModule): fuseAdded = True @@ -55,13 +55,22 @@ def call(self, graph_module: torch.fx.GraphModule): and preceding_op.target in self.FUSEABLE_CLAMPS ): # Ensure the shapes match - if "val" not in clamp_2_node.args[0].meta or "val" not in preceding_op.args[0].meta: + if ( + "val" not in clamp_2_node.args[0].meta + or "val" not in preceding_op.args[0].meta + ): continue - if len(clamp_2_node.args[0].meta["val"].shape) != len(preceding_op.args[0].meta["val"].shape): + if len(clamp_2_node.args[0].meta["val"].shape) != len( + preceding_op.args[0].meta["val"].shape + ): continue - min_max1 = self.get_output_min_max_from_activation(preceding_op) - min_max2 = self.get_output_min_max_from_activation(clamp_2_node) + min_max1 = self.get_output_min_max_from_activation( + preceding_op + ) + min_max2 = self.get_output_min_max_from_activation( + clamp_2_node + ) min_max = [None, None] @@ -71,7 +80,7 @@ def call(self, graph_module: torch.fx.GraphModule): min_max[0] = min_max1[0] else: min_max[0] = min(min_max1[0], min_max2[0]) - + if min_max1[1] is None and min_max2[1] is not None: min_max[1] = min_max2[1] elif min_max1[1] is not None and min_max2[1] is None: diff --git a/backends/transforms/fuse_conv_with_binary_op.py b/backends/transforms/fuse_conv_with_binary_op.py deleted file mode 100644 index 6cf5d6054a2..00000000000 --- a/backends/transforms/fuse_conv_with_binary_op.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import executorch.backends.vulkan.custom_ops_lib # noqa - -import torch - -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult - -class FuseConvBinaryOpPass(ExportPass): - """ - Some activations like ReLU and hardtanh can be fused with certain operators (e.g. convolution) preceding it. 
- """ - - FUSEABLE_OPS = [ - exir_ops.edge.aten.convolution.default, - ] - FUSEABLE_BINARY_OPS = [ - exir_ops.edge.aten.add.Tensor, - exir_ops.edge.aten.sub.Tensor, - exir_ops.edge.aten.mul.Tensor, - exir_ops.edge.aten.div.Tensor, - ] - - def exists_before(self, graph_module, node_a, node_b): - seen_a = False - for n in graph_module.graph.nodes: - if n is node_a: - seen_a = True - if n is node_b: - return seen_a - return False - - - def call(self, graph_module: torch.fx.GraphModule): - - fuseAdded = True - while fuseAdded: - fuseAdded = False - for arg_idx in range(0, 2): - for binary_op_node in graph_module.graph.nodes: - if binary_op_node.op == "call_function" and binary_op_node.target in self.FUSEABLE_BINARY_OPS: - preceding_op = binary_op_node.args[arg_idx] - if ( - preceding_op.op == "call_function" - and preceding_op.target in self.FUSEABLE_OPS - ): - - # For now only pw conv2d s1p0 is supported - if not (len(preceding_op.args[3]) == 2 and preceding_op.args[3][0] == 1 and preceding_op.args[3][1] == 1 and preceding_op.args[4][0] == 0 and preceding_op.args[4][1] == 0): - continue - - # Ensure the shapes match - if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta: - continue - if len(binary_op_node.args[0].meta["val"].shape) != len(binary_op_node.args[1].meta["val"].shape): - continue - - # Get the texture to do the binary op - texture = binary_op_node.args[(arg_idx + 1) % 2] - - # Fuse only if the texture exists before the preceding op - if not self.exists_before(graph_module, texture, preceding_op): - continue - - new_args = list(preceding_op.args) - new_args.append(texture) - new_args = tuple(new_args) - binary_op_node.replace_all_uses_with(preceding_op) - graph_module.graph.erase_node(binary_op_node) - - new_op = None - if binary_op_node.target == exir_ops.edge.aten.add.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_add.default - if binary_op_node.target == exir_ops.edge.aten.sub.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_sub.default - if binary_op_node.target == exir_ops.edge.aten.mul.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_mul.default - if binary_op_node.target == exir_ops.edge.aten.div.Tensor: - new_op = exir_ops.edge.et_vk.conv_with_binary_div.default - - # Create and insert node of custom op `conv_with_binary_op` - with graph_module.graph.inserting_before(preceding_op): - conv_binary_op_node = graph_module.graph.create_node( - "call_function", - new_op, - new_args, - ) - - preceding_op.replace_all_uses_with(conv_binary_op_node) - graph_module.graph.erase_node(preceding_op) - - fuseAdded = True - - graph_module.recompile() - graph_module = super().call(graph_module).graph_module - - return PassResult(graph_module, True) diff --git a/backends/transforms/fuse_conv_with_clamp.py b/backends/transforms/fuse_conv_with_clamp.py index 824d67a5c88..52fc1f4a413 100644 --- a/backends/transforms/fuse_conv_with_clamp.py +++ b/backends/transforms/fuse_conv_with_clamp.py @@ -25,6 +25,7 @@ class FuseConvClampPass(ExportPass): FUSEABLE_ACTIVATIONS = [ exir_ops.edge.aten.relu.default, exir_ops.edge.aten.hardtanh.default, + exir_ops.edge.aten.clamp.default, ] def get_output_min_max_from_activation(self, activation_node): diff --git a/backends/transforms/targets.bzl b/backends/transforms/targets.bzl index f996fec0efc..f354f2234bd 100644 --- a/backends/transforms/targets.bzl +++ b/backends/transforms/targets.bzl @@ -77,22 +77,6 @@ def define_common_targets(): ], ) - runtime.python_library( - name = "fuse_conv_with_binary_op", 
-        srcs = ["fuse_conv_with_binary_op.py"],
-        visibility = [
-            "//executorch/backends/...",
-        ],
-        deps = [
-            ":utils",
-            "//caffe2:torch",
-            "//executorch/backends/vulkan:custom_ops_lib",
-            "//executorch/exir:pass_base",
-            "//executorch/exir:sym_util",
-            "//executorch/exir/dialects:lib",
-        ],
-    )
-
     runtime.python_library(
         name = "fuse_clamps",
         srcs = ["fuse_clamps.py"],
diff --git a/backends/vulkan/custom_ops_lib.py b/backends/vulkan/custom_ops_lib.py
index c4e49ba9428..e99883cadd9 100644
--- a/backends/vulkan/custom_ops_lib.py
+++ b/backends/vulkan/custom_ops_lib.py
@@ -476,7 +476,7 @@ def clamp_with_binary_add_out_impl(
 
 name = "clamp_with_binary_add.out"
 lib.define(
-    f"{name}(Tensor input, Tensor weight, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.impl(name, clamp_with_binary_add_out_impl, "CompositeExplicitAutograd")
 
@@ -531,7 +531,7 @@ def clamp_with_binary_sub_out_impl(
 
 name = "clamp_with_binary_sub.out"
 lib.define(
-    f"{name}(Tensor input, Tensor weight, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.impl(name, clamp_with_binary_sub_out_impl, "CompositeExplicitAutograd")
 
@@ -586,7 +586,7 @@ def clamp_with_binary_mul_out_impl(
 
 name = "clamp_with_binary_mul.out"
 lib.define(
-    f"{name}(Tensor input, Tensor weight, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.impl(name, clamp_with_binary_mul_out_impl, "CompositeExplicitAutograd")
 
@@ -641,10 +641,231 @@ def clamp_with_binary_div_out_impl(
 
 name = "clamp_with_binary_div.out"
 lib.define(
-    f"{name}(Tensor input, Tensor weight, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
+    f"{name}(Tensor input, Scalar? output_min, Scalar? output_max, Tensor? other, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.impl(name, clamp_with_binary_div_out_impl, "CompositeExplicitAutograd")
 
+###########################
+## binary_add_with_clamp ##
+###########################
+
+
+def binary_add_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.add(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
+
+
+name = "binary_add_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_add_with_clamp_impl, "CompositeExplicitAutograd")
+binary_add_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_add_with_clamp.out ##
+###############################
+
+
+def binary_add_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_add_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "binary_add_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_add_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## binary_sub_with_clamp ##
+###########################
+
+
+def binary_sub_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.sub(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
+
+
+name = "binary_sub_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_sub_with_clamp_impl, "CompositeExplicitAutograd")
+binary_sub_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_sub_with_clamp.out ##
+###############################
+
+
+def binary_sub_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_sub_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "binary_sub_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_sub_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## binary_mul_with_clamp ##
+###########################
+
+
+def binary_mul_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.mul(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
+
+
+name = "binary_mul_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_mul_with_clamp_impl, "CompositeExplicitAutograd")
+binary_mul_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_mul_with_clamp.out ##
+###############################
+
+
+def binary_mul_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_mul_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "binary_mul_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_mul_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+###########################
+## binary_div_with_clamp ##
+###########################
+
+
+def binary_div_with_clamp_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+):
+    return torch.clamp(
+        torch.div(
+            input,
+            other,
+        ),
+        output_min,
+        output_max,
+    )
+
+
+name = "binary_div_with_clamp"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max) -> Tensor"
+)
+lib.impl(name, binary_div_with_clamp_impl, "CompositeExplicitAutograd")
+binary_div_with_clamp_op = getattr(getattr(torch.ops, namespace), name)
+
+###############################
+## binary_div_with_clamp.out ##
+###############################
+
+
+def binary_div_with_clamp_out_impl(
+    input,
+    other=None,
+    output_min=-float("inf"),
+    output_max=float("inf"),
+    out=None,
+):
+    out = binary_div_with_clamp_impl(
+        input,
+        other,
+        output_min,
+        output_max,
+    )
+    return out
+
+
+name = "binary_div_with_clamp.out"
+lib.define(
+    f"{name}(Tensor input, Tensor? other, Scalar? output_min, Scalar? output_max, *, Tensor(a!) out) -> Tensor(a!)"
+)
+lib.impl(name, binary_div_with_clamp_out_impl, "CompositeExplicitAutograd")
+
+
 #################
 ## grid_priors ##
 #################
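Note: each of the reference implementations above computes clamp(binary_op(input, other), output_min, output_max), with the bounds defaulting to the unbounded range (-inf, inf). A minimal eager-mode sanity check for the add variant, given as a sketch only (it assumes this module is importable as executorch.backends.vulkan.custom_ops_lib and that `namespace` is "et_vk", as for the other custom ops in this file):

    import torch

    import executorch.backends.vulkan.custom_ops_lib  # noqa: F401 (registers the et_vk ops)

    x = torch.randn(1, 8, 16, 16)
    other = torch.randn(1, 8, 16, 16)

    # The CompositeExplicitAutograd impl should match the unfused reference.
    fused = torch.ops.et_vk.binary_add_with_clamp(x, other, 0.0, 6.0)
    reference = torch.clamp(x + other, 0.0, 6.0)
    assert torch.allclose(fused, reference)
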
diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index ea1f82173d6..768e38f400f 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -211,6 +211,10 @@ def register_torchao_choose_qparams_affine():
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.gt.Tensor,
         exir_ops.edge.aten.ge.Tensor,
+        exir_ops.edge.et_vk.binary_add_with_clamp.default,
+        exir_ops.edge.et_vk.binary_sub_with_clamp.default,
+        exir_ops.edge.et_vk.binary_mul_with_clamp.default,
+        exir_ops.edge.et_vk.binary_div_with_clamp.default,
     ]
 )
 def register_binary_op():
@@ -241,7 +245,7 @@ def register_binary_op():
         exir_ops.edge.et_vk.clamp_with_binary_add.default,
         exir_ops.edge.et_vk.clamp_with_binary_sub.default,
         exir_ops.edge.et_vk.clamp_with_binary_mul.default,
-        exir_ops.edge.et_vk.clamp_with_binary_div.default
+        exir_ops.edge.et_vk.clamp_with_binary_div.default,
     ]
 )
 def register_unary_op():
@@ -475,10 +479,6 @@ def register_2d_pool_op():
     [
         exir_ops.edge.aten.convolution.default,
         exir_ops.edge.et_vk.conv_with_clamp.default,
-        exir_ops.edge.et_vk.conv_with_binary_add.default,
-        exir_ops.edge.et_vk.conv_with_binary_sub.default,
-        exir_ops.edge.et_vk.conv_with_binary_mul.default,
-        exir_ops.edge.et_vk.conv_with_binary_div.default,
     ]
 )
 def register_convolution_op():
@@ -495,7 +495,6 @@ def register_convolution_op():
         utils.NO_STORAGE,  # groups (non tensor)
         utils.NO_STORAGE,  # output_min (non tensor)
         utils.NO_STORAGE,  # output_max (non tensor)
-        utils.NO_STORAGE,  # other (prepacked)
     ],
     supports_resize=True,
     supports_prepacking=True,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
index 6f2a93667ea..ed420fcc72f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/binary_op.glsl
@@ -69,6 +69,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 ${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
 ${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
 ${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
+${layout_declare_spec_const(C, "int", "clamp_type", "0")}
+${layout_declare_spec_const(C, "float", "min_val", "0")}
+${layout_declare_spec_const(C, "float", "max_val", "0")}
 
 $if STORAGE == "buffer":
   const lowp ivec4 out_dim_order = unhash_dim_order(out_layout);
@@ -90,7 +93,20 @@ void main() {
 
   // Simple case; no broadcasting
   if (are_equal(inp, other)) {
-    t_out[out_bufi] = T(op(t_in[out_bufi], t_other[out_bufi], T(alpha)));
+    T in_val = T(t_in[out_bufi]);
+    T other_val = T(t_other[out_bufi]);
+    if (clamp_type == 1) {
+      in_val = T(clamp(in_val, T(min_val), T(max_val)));
+    }
+    else if (clamp_type == 2) {
+      other_val = T(clamp(other_val, T(min_val), T(max_val)));
+    }
+    T out_val = T(op(in_val, other_val, T(alpha)));
+    if (clamp_type == 3) {
+      out_val = T(clamp(out_val, T(min_val), T(max_val)));
+    }
+    t_out[out_bufi] = out_val;
+
     return;
   }
 
@@ -106,7 +122,19 @@ void main() {
   uint inp_bufi = tensor_idx_to_linear_idx(inp, inp_tidx);
   uint other_bufi = tensor_idx_to_linear_idx(other, other_tidx);
 
-  t_out[out_bufi] = T(op(t_in[inp_bufi], t_other[other_bufi], T(alpha)));
+  T in_val = T(t_in[inp_bufi]);
+  T other_val = T(t_other[other_bufi]);
+  if (clamp_type == 1) {
+    in_val = T(clamp(in_val, T(min_val), T(max_val)));
+  }
+  else if (clamp_type == 2) {
+    other_val = T(clamp(other_val, T(min_val), T(max_val)));
+  }
+  T out_val = T(op(in_val, other_val, T(alpha)));
+  if (clamp_type == 3) {
+    out_val = T(clamp(out_val, T(min_val), T(max_val)));
+  }
+  t_out[out_bufi] = out_val;
 }
 
 #else // USING_TEXTURE
@@ -126,6 +154,10 @@ void main() {
       // read axis mapped texel
       tidx_to_pos(in_idx, in_sizes, in_axis_map, packed_dim)));
 
+  if (clamp_type == 1) {
+    in_texel = clamp(in_texel, VEC4_T(min_val), VEC4_T(max_val));
+  }
+
   // broadcast on logical sizes
   ivec4 other_idx = broadcast_indices(tidx, other_sizes);
   VEC4_T other_texel = VEC4_T(load_texel(
@@ -133,6 +165,10 @@ void main() {
       // read axis mapped texel
       tidx_to_pos(other_idx, other_sizes, other_axis_map, packed_dim)));
 
+  if (clamp_type == 2) {
+    other_texel = clamp(other_texel, VEC4_T(min_val), VEC4_T(max_val));
+  }
+
   // Check boolean broadcast flags; we use ivec2 instead of bvec2 for alignment.
   if (broadcast_params.x > 0) {
     in_texel = in_texel.xxxx;
@@ -141,11 +177,20 @@ void main() {
     other_texel = other_texel.xxxx;
   }
 
-  write_texel_lpos(
-      t_out,
-      lpos,
-      VEC4_OUT_T(op(in_texel, other_texel, alpha)),
-      out_axis_map);
+  if (clamp_type != 3) {
+    write_texel_lpos(
+        t_out,
+        lpos,
+        VEC4_OUT_T(op(in_texel, other_texel, alpha)),
+        out_axis_map);
+  }
+  else {
+    write_texel_lpos(
+        t_out,
+        lpos,
+        VEC4_OUT_T(clamp(VEC4_OUT_T(op(in_texel, other_texel, alpha)), min_val, max_val)),
+        out_axis_map);
+  }
 }
 
 #endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
index 79577a0e2a2..ef50a1aca9f 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
@@ -30,10 +30,6 @@ ${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
 ${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
 ${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
 
-$if FUSED_OP != "A":
-  #define fused_op(A, B) ${FUSED_OP}
-  ${layout_declare_tensor(4, "r", "t_other", DTYPE, "texture3d")}
-
 layout(push_constant) uniform restrict Block {
   ivec4 out_limits;
   ivec2 stride;
@@ -141,9 +137,5 @@ void main() {
     outputTexel[3] += dot(inputVec, VEC4_T(weight1OutputChannelPacked[3], weight2OutputChannelPacked[3], weight3OutputChannelPacked[3], weight4OutputChannelPacked[3]));
   }
 
-$if FUSED_OP != "A":
-  VEC4_T otherTexel = VEC4_T(texelFetch(t_other, ivec3(xIdx, yIdx, threadOutChannel), 0));
-  fused_op(outputTexel, otherTexel);
-
-imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max));
+  imageStore(t_out, ivec3(xIdx, yIdx, threadOutChannel), op(vec4(outputTexel), out_min, out_max));
 }
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml
index 0cc872e0392..832b5bd09dc 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.yaml
@@ -19,15 +19,3 @@ conv2d_pw_s1p0:
   - NAME: conv2d_pw_s1p0
   - NAME: conv2d_pw_s1p0_clamp
     OPERATOR: clamp(X, A, B)
-  - NAME: conv2d_pw_s1p0_add
-    FUSED_OP: A += B
-    IS_FUSED: 1
-  - NAME: conv2d_pw_s1p0_sub
-    FUSED_OP: A -= B
-    IS_FUSED: 1
-  - NAME: conv2d_pw_s1p0_mul
-    FUSED_OP: A *= B
-    IS_FUSED: 1
-  - NAME: conv2d_pw_s1p0_div
-    FUSED_OP: A /= B
-    IS_FUSED: 1
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl
index d5df5bf00cf..5bc01fa7f57 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.glsl
@@ -17,9 +17,6 @@
 
 ${define_active_storage_type(STORAGE)}
 
-$if BINARY_OP != "A":
-  #define binary_op(A, B) ${BINARY_OP}
-
 #include "indexing_utils.h"
 
 ${define_required_extensions(DTYPE)}
@@ -28,8 +25,6 @@ layout(std430) buffer;
 
 ${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
 ${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
-$if BINARY_OP != "A":
-  ${layout_declare_tensor(2, "r", "t_other", DTYPE, STORAGE)}
 
 layout(push_constant) uniform restrict Block {
 $if STORAGE == "buffer":
@@ -67,10 +62,7 @@ void main() {
 
   VEC4_T in_texel = texelFetch(t_in, pos, 0);
 
-$if BINARY_OP == "A":
   imageStore(t_out, pos, VEC4_T(op(in_texel, minimum, maximum)));
-$else:
-  imageStore(t_out, pos, VEC4_T(binary_op(op(in_texel, minimum, maximum), texelFetch(t_other, pos, 0))));
 }
 
 #endif
diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
index e22affce11e..47f538aee6c 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
@@ -3,7 +3,6 @@ unary_op:
   OPERATOR: clamp(X, A, B)
   DTYPE: float
   STORAGE: texture3d
-  BINARY_OP: A
   generate_variant_forall:
     DTYPE:
       - VALUE: half
@@ -16,18 +15,6 @@ unary_op:
     OPERATOR: abs(X)
   - NAME: clamp
     OPERATOR: clamp(X, A, B)
-  - NAME: clamp_add
-    OPERATOR: clamp(X, A, B)
-    BINARY_OP: A + B
-  - NAME: clamp_sub
-    OPERATOR: clamp(X, A, B)
-    BINARY_OP: A - B
-  - NAME: clamp_mul
-    OPERATOR: clamp(X, A, B)
-    BINARY_OP: A * B
-  - NAME: clamp_div
-    OPERATOR: clamp(X, A, B)
-    BINARY_OP: A / B
   - NAME: clamp_int32
     OPERATOR: clamp(X, A, B)
     DTYPE: int32
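Note: in the BinaryOp.cpp changes below, clamp placement is communicated to binary_op.glsl through the clamp_type specialization constant. remove_clamp_from_name() strips the clamp marker from the requested op name and returns 0 (no clamp), 1 (clamp the first input), 2 (clamp the second input), or 3 (clamp the output); the macros in this patch only generate the "*_with_clamp" spelling, so only mode 3 is exercised here. A small Python sketch of the same naming convention, illustrative only (clamp_type_from_name is a hypothetical helper, not part of the patch):

    def clamp_type_from_name(op: str) -> tuple[str, int]:
        # Mirrors remove_clamp_from_name(): strip the clamp marker from the
        # shader op name and report where the clamp should be applied.
        if "clamp_0_with_" in op:
            return op.replace("clamp_0_with_", "", 1), 1  # clamp input 0
        if "clamp_1_with_" in op:
            return op.replace("clamp_1_with_", "", 1), 2  # clamp input 1
        if "_with_clamp" in op:
            return op.replace("_with_clamp", "", 1), 3  # clamp the output
        return op, 0  # no clamp


    assert clamp_type_from_name("add_with_clamp") == ("add", 3)
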
diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
index 025b483eab7..9575ca0dcdd 100644
--- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
@@ -54,13 +54,39 @@ void resize_binary_op_node(
   graph->virtual_resize(out, new_out_sizes);
 }
 
+int remove_clamp_from_name(std::string& op) {
+  if (op.find("clamp_0_with_") != std::string::npos) {
+    op.erase(op.find("clamp_0_with_"), 13);
+
+    // Clamp input 0
+    return 1;
+  }
+  if (op.find("clamp_1_with_") != std::string::npos) {
+    op.erase(op.find("clamp_1_with_"), 13);
+
+    // Clamp input 1
+    return 2;
+  }
+  if (op.find("_with_clamp") != std::string::npos) {
+    op.erase(op.find("_with_clamp"), 11);
+
+    // Clamp output
+    return 3;
+  }
+
+  // No clamp
+  return 0;
+}
+
 void add_binary_op_texture_node(
     ComputeGraph& graph,
     const ValueRef in1,
     const ValueRef in2,
     const ValueRef alpha,
     const ValueRef out,
-    const std::string& op_name) {
+    const std::string& op_name,
+    const float min,
+    const float max) {
   ValueRef arg1 = prepack_standard_like(graph, in1, out, true);
   ValueRef arg2 = prepack_standard_like(graph, in2, out, true);
 
@@ -80,7 +106,10 @@ void add_binary_op_texture_node(
 
   std::string kernel_name("binary_");
   kernel_name.reserve(kShaderNameReserve);
-  kernel_name += op_name;
+
+  std::string op = op_name;
+  int clamp_type = remove_clamp_from_name(op);
+  kernel_name += op;
 
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_dtype_suffix(kernel_name, graph.dtype_of(in1));
@@ -101,7 +130,10 @@ void add_binary_op_texture_node(
       // Specialization Constants
       {graph.hashed_layout_of(out),
        graph.hashed_layout_of(arg1),
-       graph.hashed_layout_of(arg2)},
+       graph.hashed_layout_of(arg2),
+       clamp_type,
+       min,
+       max},
       // Resize Args
       {},
       // Resizing Logic
@@ -114,7 +146,9 @@ void add_binary_op_buffer_node(
     const ValueRef in2,
     const ValueRef alpha,
     const ValueRef out,
-    const std::string& op_name) {
+    const std::string& op_name,
+    const float min,
+    const float max) {
   // check_binary_op_args(*t_in1, *t_in2, *t_out);
 
   float alpha_val = 1.0f;
@@ -126,7 +160,9 @@ void add_binary_op_buffer_node(
 
   std::string kernel_name("binary_");
   kernel_name.reserve(kShaderNameReserve);
-  kernel_name += op_name;
+  std::string op = op_name;
+  int clamp_type = remove_clamp_from_name(op);
+  kernel_name += op;
 
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
   add_dtype_suffix(kernel_name, graph.dtype_of(in1));
@@ -149,7 +185,10 @@ void add_binary_op_buffer_node(
       // Specialization Constants
       {graph.hashed_layout_of(out),
       graph.hashed_layout_of(in1),
-       graph.hashed_layout_of(in2)},
+       graph.hashed_layout_of(in2),
+       clamp_type,
+       min,
+       max},
       // Resize Args
       {},
       // Resizing Logic
@@ -162,11 +201,13 @@ void add_binary_op_node(
     const ValueRef in2,
     const ValueRef alpha,
     const ValueRef out,
-    const std::string& op_name) {
+    const std::string& op_name,
+    const float min = -std::numeric_limits<float>::infinity(),
+    const float max = std::numeric_limits<float>::infinity()) {
   if (graph.is_buffer_storage(out)) {
-    add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name);
+    add_binary_op_buffer_node(graph, in1, in2, alpha, out, op_name, min, max);
   } else {
-    add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name);
+    add_binary_op_texture_node(graph, in1, in2, alpha, out, op_name, min, max);
   }
 }
 
@@ -182,6 +223,40 @@ void add_binary_op_node(
         graph, args[0], args[1], kDummyValueRef, args[2], #op_name); \
   }
 
+float get_val_or_inf_(ComputeGraph& graph, const ValueRef& val, bool max) {
+  if (!graph.val_is_none(val)) {
+    return graph.extract_scalar<float>(val);
+  }
+  return max ? std::numeric_limits<float>::infinity()
+             : -std::numeric_limits<float>::infinity();
+}
+
+#define DEFINE_BINARY_OP_WITH_ALPHA_FN_CLAMPED(op_name)                  \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+    return add_binary_op_node(                                           \
+        graph,                                                           \
+        args[0],                                                         \
+        args[1],                                                         \
+        args[2],                                                         \
+        args[5],                                                         \
+        #op_name,                                                        \
+        get_val_or_inf_(graph, args[3], false),                          \
+        get_val_or_inf_(graph, args[4], true));                          \
+  }
+
+#define DEFINE_BINARY_OP_FN_CLAMPED(op_name)                             \
+  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
+    return add_binary_op_node(                                           \
+        graph,                                                           \
+        args[0],                                                         \
+        args[1],                                                         \
+        kDummyValueRef,                                                  \
+        args[4],                                                         \
+        #op_name,                                                        \
+        get_val_or_inf_(graph, args[2], false),                          \
+        get_val_or_inf_(graph, args[3], true));                          \
+  }
+
 DEFINE_BINARY_OP_WITH_ALPHA_FN(add);
 DEFINE_BINARY_OP_WITH_ALPHA_FN(sub);
 
@@ -199,6 +274,11 @@ DEFINE_BINARY_OP_FN(le);
 DEFINE_BINARY_OP_FN(gt);
 DEFINE_BINARY_OP_FN(ge);
 
+DEFINE_BINARY_OP_FN_CLAMPED(add_with_clamp);
+DEFINE_BINARY_OP_FN_CLAMPED(sub_with_clamp);
+DEFINE_BINARY_OP_FN_CLAMPED(mul_with_clamp);
+DEFINE_BINARY_OP_FN_CLAMPED(div_with_clamp);
+
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.add.Tensor, add);
   VK_REGISTER_OP(aten.sub.Tensor, sub);
@@ -212,6 +292,11 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.le.Tensor, le);
   VK_REGISTER_OP(aten.gt.Tensor, gt);
   VK_REGISTER_OP(aten.ge.Tensor, ge);
+
+  VK_REGISTER_OP(et_vk.binary_add_with_clamp.default, add_with_clamp);
+  VK_REGISTER_OP(et_vk.binary_sub_with_clamp.default, sub_with_clamp);
+  VK_REGISTER_OP(et_vk.binary_mul_with_clamp.default, mul_with_clamp);
+  VK_REGISTER_OP(et_vk.binary_div_with_clamp.default, div_with_clamp);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index 0fbcdcedfc2..479bb44ae6f 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -131,8 +131,7 @@ vkapi::ShaderInfo get_conv2d_shader(
     const ValueRef weight,
     const bool clamp_out = false,
     const bool stride_equals_dilation = false,
-    const bool stride_1_padding_0 = false,
-    const std::string& op_name = "") {
+    const bool stride_1_padding_0 = false) {
   std::string kernel_name;
   kernel_name.reserve(kShaderNameReserve);
   switch (method) {
@@ -157,19 +156,6 @@ vkapi::ShaderInfo get_conv2d_shader(
     } else {
       kernel_name = stride_1_padding_0 ?
"conv2d_pw_s1p0" : "conv2d_pw"; } - - // For now, binary op fusing is only supported for Conv2D PW s1p0 - if (stride_1_padding_0) { - if (op_name == "conv_add") { - kernel_name += "_add"; - } else if (op_name == "conv_sub") { - kernel_name += "_sub"; - } else if (op_name == "conv_mul") { - kernel_name += "_mul"; - } else if (op_name == "conv_div") { - kernel_name += "_div"; - } - } break; case Conv2dMethod::SlidingWindow: kernel_name = "conv2d"; @@ -459,10 +445,8 @@ void add_conv2d_node( const ValueRef groups, const ValueRef out_min, const ValueRef out_max, - const ValueRef binary_op_other, const ValueRef out, - const bool clamp_out, - const std::string& op_name) { + const bool clamp_out) { const bool transposed_val = graph.get_bool(transposed); float out_min_val = 0.0f; @@ -525,8 +509,7 @@ void add_conv2d_node( weight_data, clamp_out, stride_equals_dilation, - stride_1_padding_0, - op_name); + stride_1_padding_0); utils::uvec3 wg_size = create_conv2d_global_wg_size( graph, method, out, weight_data, stride_equals_dilation); @@ -608,27 +591,13 @@ void add_conv2d_node( }; } - ValueRef arg_binary_op_other = binary_op_other; - if (binary_op_other != kDummyValueRef) { - arg_binary_op_other = - prepack_standard_like(graph, binary_op_other, out, true); - - check_conv_args(graph, arg_binary_op_other, out); - } - - std::vector args = { - {out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}; - std::vector args_with_binary_op = { - {out, vkapi::kWrite}, - {{in, arg_weight, arg_bias, arg_binary_op_other}, vkapi::kRead}}; - graph.execute_nodes().emplace_back(new DynamicDispatchNode( graph, shader, conv2d_global_wg_size, conv2d_local_wg_size, // Inputs and Outputs - binary_op_other == kDummyValueRef ? args : args_with_binary_op, + {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, // Shader params buffers param_buffers, // Push Constants @@ -741,33 +710,11 @@ void add_conv1d_node( resize_conv1d_node)); } -void conv( - ComputeGraph& graph, - const std::vector& args, - const std::string& op_name) { +void conv(ComputeGraph& graph, const std::vector& args) { int64_t in_ndim = graph.dim_of(args[0]); if (in_ndim == 4) { if (args.size() == 10) { // ordinary conv2d - return add_conv2d_node( - graph, - args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[6], - args[7], - args[8], - /*out_min = */ kDummyValueRef, - /*out_max = */ kDummyValueRef, - /*binary_op_other = */ kDummyValueRef, - args[9], - false, - op_name); - } else if (args.size() == 11) { - // conv2d with binary op return add_conv2d_node( graph, args[0], @@ -782,9 +729,7 @@ void conv( /*out_min = */ kDummyValueRef, /*out_max = */ kDummyValueRef, args[9], - args[10], - false, - op_name); + false); } else { // conv2d with clamp return add_conv2d_node( @@ -800,10 +745,8 @@ void conv( args[8], args[9], args[10], - /*binary_op_other = */ kDummyValueRef, args[11], - true, - op_name); + true); } } else { if (args.size() == 10) { @@ -840,25 +783,10 @@ void conv( } } -#define DEFINE_CONV_BINARY_OP_FN(op_name) \ - void op_name(ComputeGraph& graph, const std::vector& args) { \ - return conv(graph, args, #op_name); \ - } - -DEFINE_CONV_BINARY_OP_FN(conv_no_op); -DEFINE_CONV_BINARY_OP_FN(conv_add); -DEFINE_CONV_BINARY_OP_FN(conv_sub); -DEFINE_CONV_BINARY_OP_FN(conv_mul); -DEFINE_CONV_BINARY_OP_FN(conv_div); - REGISTER_OPERATORS { - VK_REGISTER_OP(aten.convolution.default, conv_no_op); - VK_REGISTER_OP(conv_with_clamp.default, conv_no_op); - VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv_no_op); - 
-  VK_REGISTER_OP(et_vk.conv_with_binary_add.default, conv_add);
-  VK_REGISTER_OP(et_vk.conv_with_binary_sub.default, conv_sub);
-  VK_REGISTER_OP(et_vk.conv_with_binary_mul.default, conv_mul);
-  VK_REGISTER_OP(et_vk.conv_with_binary_div.default, conv_div);
+  VK_REGISTER_OP(aten.convolution.default, conv);
+  VK_REGISTER_OP(conv_with_clamp.default, conv);
+  VK_REGISTER_OP(et_vk.conv_with_clamp.default, conv);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
index f64133fec01..9830a8e8784 100644
--- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
@@ -39,16 +39,11 @@ void add_unary_op_node(
     const float min,
     const float max,
     const ValueRef out,
-    const std::string& op_name,
-    const ValueRef other = kDummyValueRef) {
+    const std::string& op_name) {
   std::string kernel_name(op_name);
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(out));
 
-  std::vector<ArgGroup> args = {{out, vkapi::kWrite}, {in, vkapi::kRead}};
-  std::vector<ArgGroup> args_with_binary_op = {
-      {out, vkapi::kWrite}, {{in, other}, vkapi::kRead}};
-
   const utils::vec2 min_max = {min, max};
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
@@ -56,7 +51,7 @@ void add_unary_op_node(
       default_pick_global_wg_size,
       default_pick_local_wg_size,
       // Inputs and Outputs
-      other == kDummyValueRef ? args : args_with_binary_op,
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
       // Shader params buffers
       {},
       // Push Constants
@@ -99,18 +94,6 @@ float get_val_or_inf(ComputeGraph& graph, const ValueRef& val, bool max) {
       kClampShaderName); \
   }
 
-#define DEFINE_CLAMP_BINARY_OP_FN(op_name)                                \
-  void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) {  \
-    return add_unary_op_node(                                             \
-        graph,                                                            \
-        args[0],                                                          \
-        get_val_or_inf(graph, args[1], /*max = */ false),                 \
-        get_val_or_inf(graph, args[2], /*max = */ true),                  \
-        args[args.size() - 1],                                            \
-        #op_name,                                                         \
-        args[3]);                                                         \
-  }
-
 #define DEFINE_RELU_FN(op_name) \
   void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
     return add_unary_op_node( \
@@ -176,11 +159,6 @@ DEFINE_ACTIVATION_FN(hardsigmoid);
 DEFINE_LEAKY_RELU_FN(leaky_relu);
 DEFINE_ACTIVATION_FN(round);
 
-DEFINE_CLAMP_BINARY_OP_FN(clamp_add);
-DEFINE_CLAMP_BINARY_OP_FN(clamp_sub);
-DEFINE_CLAMP_BINARY_OP_FN(clamp_mul);
-DEFINE_CLAMP_BINARY_OP_FN(clamp_div);
-
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.abs.default, abs);
   VK_REGISTER_OP(aten.clamp.default, clamp);
@@ -201,11 +179,6 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
   VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
   VK_REGISTER_OP(aten.round.default, round);
-
-  VK_REGISTER_OP(et_vk.clamp_with_binary_add.default, clamp_add);
-  VK_REGISTER_OP(et_vk.clamp_with_binary_sub.default, clamp_sub);
-  VK_REGISTER_OP(et_vk.clamp_with_binary_mul.default, clamp_mul);
-  VK_REGISTER_OP(et_vk.clamp_with_binary_div.default, clamp_div);
 }
 
 } // namespace vkcompute
diff --git a/backends/vulkan/targets.bzl b/backends/vulkan/targets.bzl
index 8a95134578d..170afe4dc44 100644
--- a/backends/vulkan/targets.bzl
+++ b/backends/vulkan/targets.bzl
@@ -383,7 +383,6 @@ def define_common_targets(is_fbcode = False):
         "//executorch/backends/transforms:fuse_batch_norm_with_conv",
         "//executorch/backends/transforms:fuse_clamp_with_binary_op",
         "//executorch/backends/transforms:fuse_clamps",
-        "//executorch/backends/transforms:fuse_conv_with_binary_op",
"//executorch/backends/transforms:fuse_conv_with_clamp", "//executorch/backends/transforms:fuse_view_copy", "//executorch/backends/transforms:remove_clone_ops", diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 494351d0fb0..f8e5aab8141 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -13,10 +13,11 @@ import executorch.backends.vulkan.utils as utils from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform -from executorch.backends.transforms.fuse_conv_with_clamp import FuseConvClampPass -from executorch.backends.transforms.fuse_conv_with_binary_op import FuseConvBinaryOpPass -from executorch.backends.transforms.fuse_clamp_with_binary_op import FuseClampBinaryOpPass +from executorch.backends.transforms.fuse_clamp_with_binary_op import ( + FuseClampBinaryOpPass, +) from executorch.backends.transforms.fuse_clamps import FuseClampsPass +from executorch.backends.transforms.fuse_conv_with_clamp import FuseConvClampPass from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import ( ViewCopyToSqueezeUnsqueezePass, @@ -172,7 +173,6 @@ def preprocess( # noqa: C901 FuseBatchNormPass(program), FuseClampsPass(), FuseConvClampPass(), - FuseConvBinaryOpPass(), FuseClampBinaryOpPass(), ], ) From 4e6469187dc00da0244ab8607172d971ceb49940 Mon Sep 17 00:00:00 2001 From: Alex Dean Date: Thu, 9 Oct 2025 09:17:43 -0700 Subject: [PATCH 7/8] Lint fix --- backends/vulkan/vulkan_preprocess.py | 1 - 1 file changed, 1 deletion(-) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 40eb4395895..001662557f8 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -173,7 +173,6 @@ def preprocess( # noqa: C901 [ FuseBatchNormPass(program), FusePatternsPass(), - FuseClampPass(), AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(), From ecc521f110a5aa4f774ac3b32754e2a40f3b8a41 Mon Sep 17 00:00:00 2001 From: Alex Dean Date: Thu, 9 Oct 2025 11:28:53 -0700 Subject: [PATCH 8/8] Change order of fusing passes --- backends/vulkan/vulkan_preprocess.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/vulkan/vulkan_preprocess.py b/backends/vulkan/vulkan_preprocess.py index 001662557f8..d23f0a29126 100644 --- a/backends/vulkan/vulkan_preprocess.py +++ b/backends/vulkan/vulkan_preprocess.py @@ -173,6 +173,9 @@ def preprocess( # noqa: C901 [ FuseBatchNormPass(program), FusePatternsPass(), + FuseClampsPass(), + FuseConvClampPass(), + FuseClampBinaryOpPass(), AddmmToLinearTransform(), RemoveRedundantOpsTransform(), FuseQuantizedOpsTransform(), @@ -181,9 +184,6 @@ def preprocess( # noqa: C901 SqueezeUnsqueezeInputs(), FuseViewCopyTransform(), ViewCopyToSqueezeUnsqueezePass(), - FuseClampsPass(), - FuseConvClampPass(), - FuseClampBinaryOpPass(), ], )