Commit befc0b9

clean up code

1 parent 04bb39e

1 file changed: +32 -120 lines

tools/llm/torchtrt_ext/sdpa_converter.py

@@ -164,130 +164,42 @@ def scaled_dot_product_attention(
     L = impl.shape.shape(ctx, target, source_ir, name + "_shape_0", query, 2)
     if S < 0:
         S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
-    if is_causal:
-        # generate the mask tensor
-        tril_tensor = tril(
-            ctx, target, source_ir, name + "_tril", L, S, sliding_window_size
-        )
 
-        temp_mask = impl.unary.logical_not(
-            ctx, target, source_ir, name + "_logical_not", tril_tensor
-        )
-
-        # This need_mask determines if we want to use the causal mask or not
-        # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
-        # So need_mask will be all False values in this case.
-        # TODO: Implement more general case where L != 1 and S != L
-        need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
-        temp_mask = impl.elementwise.logical_and(
-            ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
-        )
-        temp_mask_casted = cast_trt_tensor(
-            ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
-        )
-
-        one_minus_temp_mask = impl.elementwise.sub(
-            ctx,
-            target,
-            source_ir,
-            name + "_one_minus_temp_mask",
-            1.0,
-            temp_mask_casted,
-        )
-        attn_bias = impl.unary.log(
-            ctx, target, source_ir, name + "_log", one_minus_temp_mask
-        )
-        scaled_add_attn_bias = impl.elementwise.add(
-            ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
-        )
-    else:
-        use_if_conditional = False
-        if not use_if_conditional:
-            # works in non cache scenario, but in kv cache, got the following error:
-            # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 4: Internal Error (kOPT values for profile 0 violate shape constraints: [ELEMENTWISE]-[aten_ops.scaled_dot_product_attention]-[model.layers.0.self_attn/scaled_dot_product_attention_attn_mask_add]: dimensions not compatible for elementwise. Broadcast has incompatible dimensions: 5 != 71 && 5 != 1 && 71 != 1.)
-            scaled_add_attn_bias = impl.elementwise.add(
-                ctx, target, source_ir, name + "_attn_mask_add", mm, attn_mask
-            )
-        else:
-            if_option = "if_conditional_subgraph"  # if_conditional_subgraph or if_conditional or if_conditional_input
-            if if_option == "if_conditional_subgraph":
-                # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L46
-                # if_conditional_subgraph is not working, got the following error:
-                # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
-                # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
-
-                need_mask = impl.elementwise.eq(
-                    ctx, target, source_ir, name + "_eq", L, S
-                )
-                # if I do not squeeze, it will throw the error: condition must be a scalar tensor
-                condition = impl.squeeze.squeeze(
-                    ctx, target, source_ir, name + "_unsqueeze", need_mask, 0
-                )
-                if_layer = ctx.net.add_if_conditional()
-                if_layer.set_condition(condition)
-                cond_input1 = if_layer.add_input(mm)
-                cond_input2 = if_layer.add_input(attn_mask)
-
-                true_input = impl.elementwise.add(
-                    ctx,
-                    target,
-                    source_ir,
-                    name + "_attn_bias_add",
-                    cond_input1.get_output(0),
-                    cond_input2.get_output(0),
-                )
-                false_input = cond_input1.get_output(0)
-                output_layer = if_layer.add_output(true_input, false_input)
-                scaled_add_attn_bias = output_layer.get_output(0)
-            elif if_option == "if_conditional_input":
-                # reference: https://gitlab-master.nvidia.com/TensorRT/TensorRT/-/blob/main/documentation/operators/examples/example_if.py#L17
-                # if_conditional_input is not working, got the following error:
-                # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
-                # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
+    # generate the mask tensor
+    tril_tensor = tril(
+        ctx, target, source_ir, name + "_tril", L, S, sliding_window_size
+    )
 
-                need_mask = impl.elementwise.eq(
-                    ctx, target, source_ir, name + "_eq", L, S
-                )
-                # if I do not squeeze, it will throw the error: condition must be a scalar tensor
-                condition = impl.squeeze.squeeze(
-                    ctx, target, source_ir, name + "_unsqueeze", need_mask, 0
-                )
-                if_layer = ctx.net.add_if_conditional()
-                if_layer.set_condition(condition)
-                true_input = impl.elementwise.add(
-                    ctx, target, source_ir, name + "_attn_bias_add", mm, attn_mask
-                )
-                false_input = mm
-                true_cond_input = if_layer.add_input(true_input)
-                false_cond_input = if_layer.add_input(false_input)
-                output_layer = if_layer.add_output(
-                    true_cond_input.get_output(0), false_cond_input.get_output(0)
-                )
-                scaled_add_attn_bias = output_layer.get_output(0)
-            elif if_option == "if_conditional":
-                # reference: https://github.com/pytorch/TensorRT/blob/535c6a8341a3258a9c311406a9af50eb3c68c5a6/examples/dynamo/llm/cache_utils.py#L15-L44
-                # if_conditional is not working, got the following error:
-                # Internal Error: MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block
-                # ERROR:torch_tensorrt [TensorRT Conversion Context]:IBuilder::buildSerializedNetwork: Error Code 1: Myelin ([myelin_graph.h:attachExceptionMsgToGraph:1139] MyelinCheckException: utils.cpp:694: CHECK(common_bb == cur_call->dds_parent()->parent()) failed. Expect the graph has single block)
+    temp_mask = impl.unary.logical_not(
+        ctx, target, source_ir, name + "_logical_not", tril_tensor
+    )
 
-                need_mask = impl.elementwise.eq(
-                    ctx, target, source_ir, name + "_eq", L, S
-                )
-                # if I do not squeeze, it will throw the error: condition must be a scalar tensor
-                condition = impl.squeeze.squeeze(
-                    ctx, target, source_ir, name + "_unsqueeze", need_mask, 0
-                )
-                if_layer = ctx.net.add_if_conditional()
-                if_layer.set_condition(condition)
-                true_input = impl.elementwise.add(
-                    ctx, target, source_ir, name + "_attn_bias_add", mm, attn_mask
-                )
-                false_input = mm
-                output_layer = if_layer.add_output(
-                    true_input.get_output(0), false_input.get_output(0)
-                )
-                scaled_add_attn_bias = output_layer.get_output(0)
+    # This need_mask determines if we want to use the causal mask or not
+    # When KV caching is enabled, L = 1 and != S. In this case, we shouldn't use the causal mask.
+    # So need_mask will be all False values in this case.
+    # TODO: Implement more general case where L != 1 and S != L
+    need_mask = impl.elementwise.eq(ctx, target, source_ir, name + "_eq", L, S)
+    temp_mask = impl.elementwise.logical_and(
+        ctx, target, source_ir, name + "_logical_and", need_mask, temp_mask
+    )
+    temp_mask_casted = cast_trt_tensor(
+        ctx, temp_mask, query_dtype, name + "_casted_bool", target, source_ir
+    )
 
+    one_minus_temp_mask = impl.elementwise.sub(
+        ctx,
+        target,
+        source_ir,
+        name + "_one_minus_temp_mask",
+        1.0,
+        temp_mask_casted,
+    )
+    attn_bias = impl.unary.log(
+        ctx, target, source_ir, name + "_log", one_minus_temp_mask
+    )
+    scaled_add_attn_bias = impl.elementwise.add(
+        ctx, target, source_ir, name + "_attn_bias_add", mm, attn_bias
+    )
     softmax = impl.normalization.softmax(
         ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
    )
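For context, the surviving path builds the causal attention bias with a log(1 - mask) trick, gated by eq(L, S) so that KV-cache decode steps (where L == 1 != S) skip the mask entirely. A minimal PyTorch sketch of the same arithmetic, with made-up shapes and the sliding-window option ignored (not the converter code, which emits TensorRT layers instead):

import torch

L, S = 5, 5  # query/key lengths; during KV-cache decode L == 1 while S grows

tril_mask = torch.ones(L, S).tril().bool()  # True where attention is allowed
temp_mask = ~tril_mask                      # True where attention is masked

# Only apply the causal mask when L == S (prefill). In cached decode,
# L == 1 != S, so need_mask is False and temp_mask becomes all False.
need_mask = torch.tensor(L == S)
temp_mask = temp_mask & need_mask

# log(1 - mask) maps kept positions to log(1) = 0 and masked positions to
# log(0) = -inf, i.e. an additive bias for the pre-softmax scores.
attn_bias = torch.log(1.0 - temp_mask.float())

scores = torch.randn(L, S)  # stand-in for the q @ k^T product ("mm")
probs = torch.softmax(scores + attn_bias, dim=-1)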
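The removed else-branch recorded three failed attempts to make the mask selection a runtime decision via TensorRT's IIfConditional, all rejected by the Myelin backend per the error comments. For reference, the general shape of that API as a standalone sketch with hypothetical inputs rather than the converter's graph; note that set_condition() expects a 0-D bool tensor, which is why the removed code squeezed need_mask first:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network()  # explicit batch in recent TensorRT releases

x = network.add_input("x", trt.float32, (1, 4))        # stand-in for mm
bias = network.add_input("bias", trt.float32, (1, 4))  # stand-in for attn_mask
cond = network.add_input("cond", trt.bool, ())         # 0-D (scalar) condition

if_layer = network.add_if_conditional()
if_layer.set_condition(cond)

# Tensors consumed inside a branch are routed through add_input().
x_in = if_layer.add_input(x)
bias_in = if_layer.add_input(bias)

# True branch: x + bias; false branch: x unchanged.
true_out = network.add_elementwise(
    x_in.get_output(0), bias_in.get_output(0), trt.ElementWiseOperation.SUM
).get_output(0)
false_out = x_in.get_output(0)

out_layer = if_layer.add_output(true_out, false_out)
network.mark_output(out_layer.get_output(0))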
