
Commit fd00dac

Update on "[ET-VK] Migrate ops to use DynamicDispatchNode"
## Changes

* Migrate operators that are used in the llama model to use `DynamicDispatchNode` instead of `DispatchNode`.

## Motivation

`DynamicDispatchNode` is a subclass of `DispatchNode` that allows the compute shader, global work group size, and local work group size to be selected dynamically each time the command buffer is encoded. This is critical for achieving optimal performance when input shapes are dynamic, since it lets operators pick the best compute shader for the current input conditions and adjust the global work group size so that only the minimum number of work groups is launched.

Without this change, performance of llama 3.2 1B with dynamic shapes enabled is terrible (< 1 tok/s), because global work group sizes are determined from maximum tensor sizes, which in turn derive from the maximum sequence length. In practice, the sequence length dimension of tensors (even during the prefill phase) will not approach that maximum, so many inactive threads are launched during compute shader dispatches.

Differential Revision: [D75878398](https://our.internmc.facebook.com/intern/diff/D75878398/)

[ghstack-poisoned]
2 parents: 2a14b88 + c38171b
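For intuition, here is a minimal, self-contained sketch of the idea. The `DispatchNode`/`DynamicDispatchNode` types, the `pick_global_wgs` callback, and the extents below are hypothetical stand-ins for illustration, not the actual ExecuTorch Vulkan API:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>

// Hypothetical stand-in for a 3D dispatch extent.
struct Extents {
  uint32_t x, y, z;
};

// Static dispatch: the global work group size is fixed when the graph is
// built, so it must cover the maximum possible tensor extents.
struct DispatchNode {
  explicit DispatchNode(Extents wgs) : global_wgs(wgs) {}
  virtual ~DispatchNode() = default;
  virtual Extents encode() { return global_wgs; }
  Extents global_wgs;
};

// Dynamic dispatch: the global work group size is recomputed from the
// current tensor extents every time the command buffer is encoded.
struct DynamicDispatchNode : DispatchNode {
  explicit DynamicDispatchNode(std::function<Extents()> pick)
      : DispatchNode({0, 0, 0}), pick_global_wgs(std::move(pick)) {}
  Extents encode() override { return pick_global_wgs(); }
  std::function<Extents()> pick_global_wgs;
};

int main() {
  const uint32_t max_seq_len = 2048;  // tensors are allocated for this
  uint32_t cur_seq_len = 64;          // actual prompt length at prefill

  DispatchNode fixed({max_seq_len, 1, 1});
  DynamicDispatchNode dynamic([&] { return Extents{cur_seq_len, 1, 1}; });

  // The static node launches 2048 work groups regardless of the real
  // sequence length; the dynamic node launches only 64.
  std::cout << "static:  " << fixed.encode().x << " work groups\n";
  std::cout << "dynamic: " << dynamic.encode().x << " work groups\n";
}
```

The essential difference is the encode-time callback: the static node must size its launch for the maximum sequence length at graph build time, while the dynamic node can size it for the actual sequence length each time the command buffer is encoded.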

File tree

2 files changed (+4 −0):

* backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
* backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp


backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -46,3 +46,5 @@ unary_op:
     OPERATOR: leaky_relu(X, A)
   - NAME: round
     OPERATOR: round(X)
+  - NAME: tan
+    OPERATOR: tan(X)
```

backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -154,6 +154,7 @@ DEFINE_ACTIVATION_FN(hardswish);
 DEFINE_ACTIVATION_FN(hardsigmoid);
 DEFINE_LEAKY_RELU_FN(leaky_relu);
 DEFINE_ACTIVATION_FN(round);
+DEFINE_ACTIVATION_FN(tan);
 
 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.abs.default, abs);
@@ -174,6 +175,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
   VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
   VK_REGISTER_OP(aten.round.default, round);
+  VK_REGISTER_OP(aten.tan.default, tan);
 }
 
 } // namespace vkcompute
```
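As an aside on the pattern above: each new unary op pairs a YAML shader variant with a macro-generated function and a registry entry. A rough, self-contained sketch of how such a define/register scheme can work is below; the expansions of `DEFINE_ACTIVATION_FN` and `VK_REGISTER_OP` shown here are assumptions for illustration, not the real definitions in UnaryOp.cpp:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>

struct ComputeGraph {};  // hypothetical stand-in for the real graph type

using OpFn = std::function<void(ComputeGraph&)>;

// Hypothetical registry mapping ATen op names to graph-building functions.
static std::map<std::string, OpFn>& registry() {
  static std::map<std::string, OpFn> r;
  return r;
}

// Stand-in for the real add_unary_op: would resolve the shader variant
// generated from unary_op.yaml (e.g. the "tan" variant added above).
void add_unary_op(ComputeGraph&, const std::string& shader_name) {
  std::cout << "would dispatch unary_op shader variant: " << shader_name << "\n";
}

// Each DEFINE_ACTIVATION_FN stamps out a graph-building function whose
// name doubles as the shader variant name.
#define DEFINE_ACTIVATION_FN(op_name) \
  void op_name(ComputeGraph& g) { add_unary_op(g, #op_name); }

// VK_REGISTER_OP maps the ATen operator name to that function.
#define VK_REGISTER_OP(aten_name, fn) \
  registry()[#aten_name] = [](ComputeGraph& g) { fn(g); }

DEFINE_ACTIVATION_FN(tan);

int main() {
  VK_REGISTER_OP(aten.tan.default, tan);
  ComputeGraph g;
  registry()["aten.tan.default"](g);  // invokes tan(g)
}
```

Under this reading, the two-line diff is the entire cost of adding an op: one YAML entry to generate the shader variant, one `DEFINE_ACTIVATION_FN`, and one `VK_REGISTER_OP` to expose it under its ATen name.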
