From 6605ac9d012a272f0b0fed05f0f04da6073fcdc2 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Tue, 2 Jul 2024 17:25:46 +0300 Subject: [PATCH 001/171] Backend optimized, naive search setup --- examples/dev/LLaMAMLP.py | 28 + examples/dev/backward-log.out | 1305 ++ examples/dev/backward_trc | 96 + examples/dev/backward_trc.pdf | Bin 0 -> 17024 bytes examples/dev/backward_trc_final | 63 + examples/dev/backward_trc_final.pdf | Bin 0 -> 14194 bytes examples/dev/backward_trc_fusion | 63 + examples/dev/backward_trc_fusion.pdf | Bin 0 -> 14244 bytes examples/dev/forward_trc | 45 + examples/dev/forward_trc.dot | 4518 ++++ examples/dev/forward_trc.pdf | Bin 0 -> 13368 bytes examples/dev/litGPT.out | 25381 +++++++++++++++++++++++ examples/dev/litGPT.py | 16 + examples/dev/log.out | 795 + examples/dev/my_graph.png | Bin 0 -> 144 bytes examples/dev/simple.py | 26 + examples/dev/simple_log.out | 132 + thunder/backend_optimizer/optimizer.py | 80 + thunder/common.py | 3 +- thunder/core/jit_ext.py | 1 + thunder/core/prims.py | 2 +- thunder/core/transforms.py | 9 + thunder/executors/passes.py | 87 +- thunder/executors/torch_autograd.py | 19 + thunder/visualizer/__init__.py | 0 thunder/visualizer/graphviz.py | 37 + 26 files changed, 32702 insertions(+), 4 deletions(-) create mode 100644 examples/dev/LLaMAMLP.py create mode 100644 examples/dev/backward-log.out create mode 100644 examples/dev/backward_trc create mode 100644 examples/dev/backward_trc.pdf create mode 100644 examples/dev/backward_trc_final create mode 100644 examples/dev/backward_trc_final.pdf create mode 100644 examples/dev/backward_trc_fusion create mode 100644 examples/dev/backward_trc_fusion.pdf create mode 100644 examples/dev/forward_trc create mode 100644 examples/dev/forward_trc.dot create mode 100644 examples/dev/forward_trc.pdf create mode 100644 examples/dev/litGPT.out create mode 100644 examples/dev/litGPT.py create mode 100644 examples/dev/log.out create mode 100644 examples/dev/my_graph.png create mode 100644 examples/dev/simple.py create mode 100644 examples/dev/simple_log.out create mode 100644 thunder/backend_optimizer/optimizer.py create mode 100644 thunder/visualizer/__init__.py create mode 100644 thunder/visualizer/graphviz.py diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py new file mode 100644 index 0000000000..9a3418a447 --- /dev/null +++ b/examples/dev/LLaMAMLP.py @@ -0,0 +1,28 @@ +import torch +import thunder + +class LLaMAMLP(torch.nn.Module): + def __init__(self, n_embd, intermediate_size) -> None: + super().__init__() + self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False) + self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False) + self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + return self.proj(x) + +with torch.device('cuda'): + model = LLaMAMLP(4096, 11008) + x = torch.randn(2, 2048, 4096, requires_grad=True) + + jmodel = thunder.jit(model) + + ans = jmodel(x) + print('---------------------------------------------- all traces') + for t in thunder.last_traces(jmodel): + print(t) + print('##############################################') + print('---------------------------------------------- ans') + print(ans) diff --git a/examples/dev/backward-log.out b/examples/dev/backward-log.out new file mode 100644 index 0000000000..f74f642d6f --- /dev/null +++ b/examples/dev/backward-log.out @@ -0,0 +1,1305 @@ +============================================ START: LABEL default +============================================ START: computation_trc split_forward_backward +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +============================================ END: computation_trc split_forward_backward +============================================ START: primal_trace sort_data_parallel_syncs +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +============================================ END: primal_trace sort_data_parallel_syncs +============================================ START: augmented_forward_pass +result +t8 +env +{'a': VJPDual(primal=t6, + residuals=((t0, False), None, ([], [t3], [], [t5], [t5, t0]))), + 'result': VJPDual(primal=t7, residuals=((t6, t1), None, ([t1, t6],))), + 't18': VJPDual(primal=t8, + residuals=((t7, t_proj_weight, None), + None, + ([t_proj_weight, t7],))), + 't_fc_1_weight': VJPDual(primal=t_fc_1_weight, residuals=()), + 't_fc_2_weight': VJPDual(primal=t_fc_2_weight, residuals=()), + 't_proj_weight': VJPDual(primal=t_proj_weight, residuals=()), + 'x': VJPDual(primal=x, residuals=()), + 'x_fc_1': VJPDual(primal=t0, + residuals=((x, t_fc_1_weight, None), + None, + ([t_fc_1_weight, x],))), + 'x_fc_2': VJPDual(primal=t1, + residuals=((x, t_fc_2_weight, None), + None, + ([t_fc_2_weight, x],)))} +============================================ END: augmented_forward_pass +============================================ START: primal_trace forward_and_backward_from_trace +# Constructed by Augmented forward pass +import thunder +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) +============================================ END: primal_trace forward_and_backward_from_trace +============================================ START: before forward_trc transform_for_execution +# Constructed by Augmented forward pass +import thunder +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) +============================================ END: after forward_trc transform_for_execution +============================================ START: LABEL forward_trc +============================================ START: LABEL backward_trc +============================================ START: before _transform_for_operator_executor_execution +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + t9, = cotangents + x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 + t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" + # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" + t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + t28 = ltorch.reshape(t9, -1, 4096) # t28: "cuda:0 f32[4096, 4096]" + # t28 = prims.reshape(t9, (4096, 4096)) # t28: "cuda:0 f32[4096, 4096]" + t29 = prims.transpose(t28, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" + # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" + t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + t32 = ltorch.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + t33 = ltorch.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + t34 = ltorch.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + t35 = ltorch.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + t36 = ltorch.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + t37 = ltorch.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + t38 = ltorch.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + t39 = ltorch.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + t40 = ltorch.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" + t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" + # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" + t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + t45 = ltorch.reshape(t33, -1, 11008) # t45: "cuda:0 f32[4096, 11008]" + # t45 = prims.reshape(t33, (4096, 11008)) # t45: "cuda:0 f32[4096, 11008]" + t46 = prims.transpose(t45, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" + # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" + t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" + # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" + t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + t52 = ltorch.reshape(t41, -1, 11008) # t52: "cuda:0 f32[4096, 11008]" + # t52 = prims.reshape(t41, (4096, 11008)) # t52: "cuda:0 f32[4096, 11008]" + t53 = prims.transpose(t52, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + t54 = ltorch.reshape(x, -1, 4096) # t54: "cuda:0 f32[4096, 4096]" + # t54 = prims.reshape(x, (4096, 4096)) # t54: "cuda:0 f32[4096, 4096]" + t55 = ltorch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" + # t55 = prims.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" + t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" + return (t56, t55, t48, t31) +============================================ END: before _transform_for_operator_executor_execution +============================================ START: after _transform_for_operator_executor_execution +# Constructed by Transform for operator executor execution (took 1 milliseconds) +import thunder +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + t9, = cotangents + x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 + t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" + # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" + t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + t28 = ltorch.reshape(t9, -1, 4096) # t28: "cuda:0 f32[4096, 4096]" + # t28 = prims.reshape(t9, (4096, 4096)) # t28: "cuda:0 f32[4096, 4096]" + t29 = prims.transpose(t28, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" + # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" + t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + t32 = ltorch.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + t33 = ltorch.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + t34 = ltorch.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + t35 = ltorch.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + t36 = ltorch.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + t37 = ltorch.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + t38 = ltorch.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + t39 = ltorch.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + t40 = ltorch.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" + t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" + # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" + t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + t45 = ltorch.reshape(t33, -1, 11008) # t45: "cuda:0 f32[4096, 11008]" + # t45 = prims.reshape(t33, (4096, 11008)) # t45: "cuda:0 f32[4096, 11008]" + t46 = prims.transpose(t45, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" + # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" + t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" + # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" + t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + t52 = ltorch.reshape(t41, -1, 11008) # t52: "cuda:0 f32[4096, 11008]" + # t52 = prims.reshape(t41, (4096, 11008)) # t52: "cuda:0 f32[4096, 11008]" + t53 = prims.transpose(t52, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + t54 = ltorch.reshape(x, -1, 4096) # t54: "cuda:0 f32[4096, 4096]" + # t54 = prims.reshape(x, (4096, 4096)) # t54: "cuda:0 f32[4096, 4096]" + t55 = torch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" + # t55 = ltorch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" + # t55 = prims.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" + t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" + return (t56, t55, t48, t31) +============================================ GRAPH: _transform_for_operator_executor_execution +graph roots: 0, 1, +traversal nodes: +node ID 0 : [# saved_for_backward: "Collection"] + parents ids: + children ids: 2, +node ID 1 : [# cotangents: "Collection"] + parents ids: + children ids: 3, +node ID 2 : [C0, _, = saved_for_backward] + parents ids: 0, + children ids: 4, +node ID 3 : [t9, = cotangents] + parents ids: 1, + children ids: 8, 5, +node ID 4 : [x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0] + parents ids: 2, + children ids: 34, 6, 10, 12, 13, 14, 15, 17, 18, 19, 23, 27, 30, +node ID 8 : [t28 = ltorch.reshape(t9, -1, 4096) # t28: "cuda:0 f32[4096, 4096]" + # t28 = prims.reshape(t9, (4096, 4096)) # t28: "cuda:0 f32[4096, 4096]"] + parents ids: 3, + children ids: 9, +node ID 5 : [t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" + # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]"] + parents ids: 3, + children ids: 6, +node ID 34 : [t54 = ltorch.reshape(x, -1, 4096) # t54: "cuda:0 f32[4096, 4096]" + # t54 = prims.reshape(x, (4096, 4096)) # t54: "cuda:0 f32[4096, 4096]"] + parents ids: 4, + children ids: 35, +node ID 6 : [t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]"] + parents ids: 4, 5, + children ids: 7, +node ID 10 : [t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" + # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]"] + parents ids: 4, + children ids: 11, +node ID 12 : [t32 = ltorch.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 4, 7, + children ids: 14, 15, +node ID 13 : [t33 = ltorch.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 4, 7, + children ids: 25, 22, +node ID 14 : [t34 = ltorch.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 4, 12, + children ids: 21, +node ID 15 : [t35 = ltorch.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 4, 12, + children ids: 16, +node ID 17 : [t37 = ltorch.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 16, 4, + children ids: 18, +node ID 18 : [t38 = ltorch.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 17, 4, + children ids: 19, +node ID 19 : [t39 = ltorch.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 18, 4, + children ids: 20, +node ID 23 : [t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]"] + parents ids: 4, 22, + children ids: 24, +node ID 27 : [t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" + # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]"] + parents ids: 4, + children ids: 28, +node ID 30 : [t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]"] + parents ids: 4, 29, + children ids: 31, +node ID 9 : [t29 = prims.transpose(t28, (1, 0)) # t29: "cuda:0 f32[4096, 4096]"] + parents ids: 8, + children ids: 11, +node ID 35 : [t55 = ltorch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" + # t55 = prims.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]"] + parents ids: 33, 34, + children ids: 37, +node ID 7 : [t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 6, + children ids: 12, 13, +node ID 11 : [t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]"] + parents ids: 9, 10, + children ids: 37, +node ID 25 : [t45 = ltorch.reshape(t33, -1, 11008) # t45: "cuda:0 f32[4096, 11008]" + # t45 = prims.reshape(t33, (4096, 11008)) # t45: "cuda:0 f32[4096, 11008]"] + parents ids: 13, + children ids: 26, +node ID 22 : [t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" + # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]"] + parents ids: 13, + children ids: 23, +node ID 21 : [t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 20, 14, + children ids: 32, 29, +node ID 16 : [t36 = ltorch.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 15, + children ids: 17, +node ID 20 : [t40 = ltorch.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 19, + children ids: 21, +node ID 24 : [t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 23, + children ids: 36, +node ID 28 : [t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]"] + parents ids: 26, 27, + children ids: 37, +node ID 31 : [t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 30, + children ids: 36, +node ID 37 : [return (t56, t55, t48, t31)] + parents ids: 11, 35, 36, 28, + children ids: +node ID 26 : [t46 = prims.transpose(t45, (1, 0)) # t46: "cuda:0 f32[11008, 4096]"] + parents ids: 25, + children ids: 28, +node ID 32 : [t52 = ltorch.reshape(t41, -1, 11008) # t52: "cuda:0 f32[4096, 11008]" + # t52 = prims.reshape(t41, (4096, 11008)) # t52: "cuda:0 f32[4096, 11008]"] + parents ids: 21, + children ids: 33, +node ID 29 : [t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" + # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]"] + parents ids: 21, + children ids: 30, +node ID 36 : [t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 24, 31, + children ids: 37, +node ID 33 : [t53 = prims.transpose(t52, (1, 0)) # t53: "cuda:0 f32[11008, 4096]"] + parents ids: 32, + children ids: 35, + +============================================ END: after _transform_for_operator_executor_execution +============================================ START: after fusion_pass +# Constructed by Fusion (took 3 milliseconds) +import thunder +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + t9, = cotangents + x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 + t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" + # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" + t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" + # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" + t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" + # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" + t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + [t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) + # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" + t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" + # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" + t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" + # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" + t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + [t56] = nvFusion1(t44, t51) + # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" + return (t56, t55, t48, t31) +============================================ GRAPH: fusion_pass +graph roots: 0, 1, +traversal nodes: +node ID 0 : [# saved_for_backward: "Collection"] + parents ids: + children ids: 2, +node ID 1 : [# cotangents: "Collection"] + parents ids: + children ids: 3, +node ID 2 : [C0, _, = saved_for_backward] + parents ids: 0, + children ids: 4, +node ID 3 : [t9, = cotangents] + parents ids: 1, + children ids: 5, +node ID 4 : [x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0] + parents ids: 2, + children ids: 7, 8, 9, 12, 19, 20, +node ID 5 : [t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" + # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]"] + parents ids: 3, + children ids: 9, 6, +node ID 7 : [t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" + # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]"] + parents ids: 4, + children ids: 10, +node ID 8 : [t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" + # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]"] + parents ids: 4, + children ids: 17, 18, +node ID 9 : [t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]"] + parents ids: 4, 5, + children ids: 11, +node ID 12 : [[t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) + # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 11, 4, + children ids: 13, 15, +node ID 19 : [t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]"] + parents ids: 4, 15, + children ids: 22, +node ID 20 : [t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]"] + parents ids: 4, 13, + children ids: 21, +node ID 6 : [t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]"] + parents ids: 5, + children ids: 10, +node ID 10 : [t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]"] + parents ids: 6, 7, + children ids: 24, +node ID 17 : [t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]"] + parents ids: 16, 8, + children ids: 24, +node ID 18 : [t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]"] + parents ids: 8, 14, + children ids: 24, +node ID 11 : [t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 9, + children ids: 12, +node ID 13 : [t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" + # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]"] + parents ids: 12, + children ids: 20, 14, +node ID 15 : [t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" + # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]"] + parents ids: 12, + children ids: 16, 19, +node ID 22 : [t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 19, + children ids: 23, +node ID 21 : [t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 20, + children ids: 23, +node ID 24 : [return (t56, t55, t48, t31)] + parents ids: 17, 18, 10, 23, + children ids: +node ID 14 : [t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]"] + parents ids: 13, + children ids: 18, +node ID 16 : [t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]"] + parents ids: 15, + children ids: 17, +node ID 23 : [[t56] = nvFusion1(t44, t51) + # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 21, 22, + children ids: 24, + +============================================ END: after fusion_pass +============================================ START: after _transform_for_operator_executor_execution (always) +# Constructed by Transform for operator executor execution (took 1 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + t9, = cotangents + x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 + t25 = torch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" + # t25 = ltorch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" + # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" + t29 = torch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + # t29 = ltorch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + # t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + t30 = torch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" + # t30 = ltorch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" + # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" + t47 = torch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" + # t47 = ltorch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" + # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" + t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + t27 = torch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + [t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) + # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" + t42 = torch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" + # t42 = ltorch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" + # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" + t46 = torch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + # t46 = ltorch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + # t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + t49 = torch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" + # t49 = ltorch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" + # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" + t53 = torch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + # t53 = ltorch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + # t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + t44 = torch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + t51 = torch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + [t56] = nvFusion1(t44, t51) + # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" + return (t56, t55, t48, t31) +============================================ GRAPH: fusion_pass +graph roots: 0, 1, +traversal nodes: +node ID 0 : [# saved_for_backward: "Collection"] + parents ids: + children ids: 2, +node ID 1 : [# cotangents: "Collection"] + parents ids: + children ids: 3, +node ID 2 : [C0, _, = saved_for_backward] + parents ids: 0, + children ids: 4, +node ID 3 : [t9, = cotangents] + parents ids: 1, + children ids: 5, +node ID 4 : [x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0] + parents ids: 2, + children ids: 7, 8, 9, 12, 19, 20, +node ID 5 : [t25 = torch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" + # t25 = ltorch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" + # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]"] + parents ids: 3, + children ids: 9, 6, +node ID 7 : [t30 = torch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" + # t30 = ltorch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" + # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]"] + parents ids: 4, + children ids: 10, +node ID 8 : [t47 = torch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" + # t47 = ltorch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" + # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]"] + parents ids: 4, + children ids: 17, 18, +node ID 9 : [t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" + # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]"] + parents ids: 4, 5, + children ids: 11, +node ID 12 : [[t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) + # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" + # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" + # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" + # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" + # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" + # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" + # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" + # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" + # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" + # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 11, 4, + children ids: 13, 15, +node ID 19 : [t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" + # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]"] + parents ids: 4, 15, + children ids: 22, +node ID 20 : [t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" + # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]"] + parents ids: 4, 13, + children ids: 21, +node ID 6 : [t29 = torch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + # t29 = ltorch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" + # t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]"] + parents ids: 5, + children ids: 10, +node ID 10 : [t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" + # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]"] + parents ids: 6, 7, + children ids: 24, +node ID 17 : [t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" + # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]"] + parents ids: 16, 8, + children ids: 24, +node ID 18 : [t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" + # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]"] + parents ids: 8, 14, + children ids: 24, +node ID 11 : [t27 = torch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" + # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]"] + parents ids: 9, + children ids: 12, +node ID 13 : [t42 = torch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" + # t42 = ltorch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" + # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]"] + parents ids: 12, + children ids: 20, 14, +node ID 15 : [t49 = torch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" + # t49 = ltorch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" + # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]"] + parents ids: 12, + children ids: 16, 19, +node ID 22 : [t51 = torch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" + # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 19, + children ids: 23, +node ID 21 : [t44 = torch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" + # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 20, + children ids: 23, +node ID 24 : [return (t56, t55, t48, t31)] + parents ids: 17, 18, 10, 23, + children ids: +node ID 14 : [t46 = torch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + # t46 = ltorch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" + # t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]"] + parents ids: 13, + children ids: 18, +node ID 16 : [t53 = torch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + # t53 = ltorch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" + # t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]"] + parents ids: 15, + children ids: 17, +node ID 23 : [[t56] = nvFusion1(t44, t51) + # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]"] + parents ids: 21, 22, + children ids: 24, + +============================================ END: after _transform_for_operator_executor_execution (always) +---------------------------------------------- all traces +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +############################################## +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +############################################## +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +############################################## +# Constructed by Augmented forward pass +import thunder +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) +############################################## +# Constructed by Transform for execution (took 2 milliseconds) +import torch +import torch.nn.functional +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + [t3, t5, t6, t7] = nvFusion0(t0, t1) + # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) +############################################## +# Constructed by Update Call Context (took 0 milliseconds) +import torch +import torch.nn.functional +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + [t7] = nvFusion0(t0, t1) + # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) +############################################## +# Constructed by Delete Last Used (took 0 milliseconds) +import torch +import torch.nn.functional +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + [t7] = nvFusion0(t0, t1) + # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) +############################################## +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +############################################## +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +############################################## +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" + x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" + + # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" + # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" + # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" + # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" + # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" + # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" + result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" + + # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) + t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" + return t18 +############################################## +# Constructed by Augmented forward pass +import thunder +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) +############################################## +# Constructed by Transform for execution (took 2 milliseconds) +import torch +import torch.nn.functional +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + [t3, t5, t6, t7] = nvFusion0(t0, t1) + # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) +############################################## +# Constructed by Update Call Context (took 0 milliseconds) +import torch +import torch.nn.functional +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + [t7] = nvFusion0(t0, t1) + # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) +############################################## +# Constructed by Delete Last Used (took 0 milliseconds) +import torch +import torch.nn.functional +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): + # x: "cuda:0 f32[2, 2048, 4096]" + # t_fc_1_weight: "cuda:0 f32[11008, 4096]" + # t_fc_2_weight: "cuda:0 f32[11008, 4096]" + # t_proj_weight: "cuda:0 f32[4096, 11008]" + t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" + t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" + [t7] = nvFusion0(t0, t1) + # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" + # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" + # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" + # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" + # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" + # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" + t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" + return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) +############################################## +---------------------------------------------- ans +tensor([[[-0.0380, 0.1292, 0.0922, ..., 0.0574, -0.0760, -0.0142], + [-0.0312, -0.1352, 0.1404, ..., -0.0036, -0.1777, -0.0775], + [ 0.0121, -0.0281, -0.1634, ..., 0.0387, -0.2150, 0.0118], + ..., + [-0.1302, 0.0754, -0.1463, ..., -0.0835, -0.1263, 0.1630], + [-0.0158, 0.2085, 0.0153, ..., -0.0273, -0.0947, -0.0970], + [-0.2236, -0.1944, 0.0894, ..., 0.0347, -0.0962, 0.1017]], + + [[ 0.0363, -0.1088, 0.1518, ..., 0.0293, 0.1325, 0.0490], + [-0.1212, -0.2084, 0.1211, ..., -0.1555, 0.0875, -0.0580], + [ 0.1207, -0.0828, -0.0089, ..., 0.0490, 0.0931, 0.0576], + ..., + [-0.0100, 0.0776, 0.1118, ..., 0.0961, 0.0167, 0.0933], + [-0.1560, 0.0455, -0.0116, ..., 0.0028, -0.0157, -0.0022], + [ 0.3174, 0.0314, -0.0429, ..., 0.1140, 0.0264, 0.0614]]], + device='cuda:0', grad_fn=) diff --git a/examples/dev/backward_trc b/examples/dev/backward_trc new file mode 100644 index 0000000000..84a36f90a6 --- /dev/null +++ b/examples/dev/backward_trc @@ -0,0 +1,96 @@ +digraph { + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "2([Symbol name=unpack_sequence])" + "0([Symbol name=unpack_trivial])" -> "2([Symbol name=unpack_sequence])" + "3([Symbol name=unpack_sequence])" + "1([Symbol name=unpack_trivial])" -> "3([Symbol name=unpack_sequence])" + "4([Symbol name=unpack_sequence])" + "2([Symbol name=unpack_sequence])" -> "4([Symbol name=unpack_sequence])" + "8([Symbol name=reshape])" + "3([Symbol name=unpack_sequence])" -> "8([Symbol name=reshape])" + "5([Symbol name=reshape])" + "3([Symbol name=unpack_sequence])" -> "5([Symbol name=reshape])" + "34([Symbol name=reshape])" + "4([Symbol name=unpack_sequence])" -> "34([Symbol name=reshape])" + "6([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "6([Symbol name=matmul])" + "5([Symbol name=reshape])" -> "6([Symbol name=matmul])" + "10([Symbol name=reshape])" + "4([Symbol name=unpack_sequence])" -> "10([Symbol name=reshape])" + "12([Symbol name=mul])" + "4([Symbol name=unpack_sequence])" -> "12([Symbol name=mul])" + "7([Symbol name=reshape])" -> "12([Symbol name=mul])" + "13([Symbol name=mul])" + "4([Symbol name=unpack_sequence])" -> "13([Symbol name=mul])" + "7([Symbol name=reshape])" -> "13([Symbol name=mul])" + "14([Symbol name=mul])" + "4([Symbol name=unpack_sequence])" -> "14([Symbol name=mul])" + "12([Symbol name=mul])" -> "14([Symbol name=mul])" + "15([Symbol name=mul])" + "4([Symbol name=unpack_sequence])" -> "15([Symbol name=mul])" + "12([Symbol name=mul])" -> "15([Symbol name=mul])" + "17([Symbol name=mul])" + "16([Symbol name=neg])" -> "17([Symbol name=mul])" + "4([Symbol name=unpack_sequence])" -> "17([Symbol name=mul])" + "18([Symbol name=mul])" + "17([Symbol name=mul])" -> "18([Symbol name=mul])" + "4([Symbol name=unpack_sequence])" -> "18([Symbol name=mul])" + "19([Symbol name=mul])" + "18([Symbol name=mul])" -> "19([Symbol name=mul])" + "4([Symbol name=unpack_sequence])" -> "19([Symbol name=mul])" + "23([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "23([Symbol name=matmul])" + "22([Symbol name=reshape])" -> "23([Symbol name=matmul])" + "27([Symbol name=reshape])" + "4([Symbol name=unpack_sequence])" -> "27([Symbol name=reshape])" + "30([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "30([Symbol name=matmul])" + "29([Symbol name=reshape])" -> "30([Symbol name=matmul])" + "9([Symbol name=transpose])" + "8([Symbol name=reshape])" -> "9([Symbol name=transpose])" + "35([Symbol name=matmul])" + "33([Symbol name=transpose])" -> "35([Symbol name=matmul])" + "34([Symbol name=reshape])" -> "35([Symbol name=matmul])" + "7([Symbol name=reshape])" + "6([Symbol name=matmul])" -> "7([Symbol name=reshape])" + "11([Symbol name=matmul])" + "9([Symbol name=transpose])" -> "11([Symbol name=matmul])" + "10([Symbol name=reshape])" -> "11([Symbol name=matmul])" + "25([Symbol name=reshape])" + "13([Symbol name=mul])" -> "25([Symbol name=reshape])" + "22([Symbol name=reshape])" + "13([Symbol name=mul])" -> "22([Symbol name=reshape])" + "21([Symbol name=add])" + "20([Symbol name=neg])" -> "21([Symbol name=add])" + "14([Symbol name=mul])" -> "21([Symbol name=add])" + "16([Symbol name=neg])" + "15([Symbol name=mul])" -> "16([Symbol name=neg])" + "20([Symbol name=neg])" + "19([Symbol name=mul])" -> "20([Symbol name=neg])" + "24([Symbol name=reshape])" + "23([Symbol name=matmul])" -> "24([Symbol name=reshape])" + "28([Symbol name=matmul])" + "26([Symbol name=transpose])" -> "28([Symbol name=matmul])" + "27([Symbol name=reshape])" -> "28([Symbol name=matmul])" + "31([Symbol name=reshape])" + "30([Symbol name=matmul])" -> "31([Symbol name=reshape])" + "37([Symbol name=return])" + "11([Symbol name=matmul])" -> "37([Symbol name=return])" + "35([Symbol name=matmul])" -> "37([Symbol name=return])" + "36([Symbol name=add])" -> "37([Symbol name=return])" + "28([Symbol name=matmul])" -> "37([Symbol name=return])" + "26([Symbol name=transpose])" + "25([Symbol name=reshape])" -> "26([Symbol name=transpose])" + "32([Symbol name=reshape])" + "21([Symbol name=add])" -> "32([Symbol name=reshape])" + "29([Symbol name=reshape])" + "21([Symbol name=add])" -> "29([Symbol name=reshape])" + "36([Symbol name=add])" + "24([Symbol name=reshape])" -> "36([Symbol name=add])" + "31([Symbol name=reshape])" -> "36([Symbol name=add])" + "33([Symbol name=transpose])" + "32([Symbol name=reshape])" -> "33([Symbol name=transpose])" +} diff --git a/examples/dev/backward_trc.pdf b/examples/dev/backward_trc.pdf new file mode 100644 index 0000000000000000000000000000000000000000..208da6ea8a2479559a06032d0d74978f696871e4 GIT binary patch literal 17024 zcma*P1y~(R7A>4WfZ!I~-Q@rWcXxMpcemi~?hxE9xVyUq4ess|G(Y6dop)#E{onVx zIo-8)?W$e1tGdsj*6Jpa6%?ic&;p@I>Q5`5pqTLK@on_Yp*T75=_HM;O&m?}fp02# zD13Z;I$<+QM;{h@@9s^10BT z@4L6~j{Vs(zBP7Q31j}g{RY?hFo50h^kfYGJF;PY>$ms1*YDfzlf8?g@w8qZyz)q| zDpDMMHclDJ-F9$~YSIO+eAl5vS&6my$!9n21k8JV%FwNoGleyFnY^Viod;~$pXh=H>cOtu&z0>f@as5P9m|%BkpS8NW@vxv(A6T#V}d2Fxuc8Yvi(yQ z?^j)J-BLK67q_RE=?w0jZM-P5^r%jrCD-TGLw~2i-)vuPH-AN}AO0Ggh46l!jK3X> z4xSBmeZCnlsgqcL1%2JwmiZ=~&f_)t)WNp9{qSk+H|u6kXQn0FYe>h{<(sYT^Rf4) z3;XMBx&o)y^P^$!C4ABBg3gQc(_qz}pIZ7+mSmqkrdkZ&d;?*@3F@`60EWZRdWwZ? zQlrmuSg41%(S=mnH5a*wqZWIaQ?0@?w*hbfly7pAX&`)at&Gi0=VZv@{Jbp$h)g0ZUc{1y4 z_e1IVuFvhD@hZ56g{g{-IZO0vPTMazXKVwQD{&Nk?Zs$#ZuF&dB|de(uulmZ(Q-)+O|_9tAfD&KKL|Y{(F}{s|q|!Ta?! zKXGTpu!JLPcwdQlfmN$M%fM<6UIK+lwzZG3pAL4?P%Cs{>4V^^VJJ?!+0ve)Y{*lJ zvu1VfsT~gPWmjP>mrMz6VM)u5)=ijwkq_#okVrwrXOgYY>kN>AZC1ibjPbP?AlX=ZF(Mx+U45PhH=|JN2T=4nq4wHwUBm zi%WYd7xCS*)tsqKg8J`0(Gwymx#GH+AfhX}Gmrg_49I#uY8bk3e{NRO8PksYR-V^A z<%MGiqAuj_c&9jWg4s9F`sA>ykM(0iGsv@ZvD-Od1tr1II-}8&;Gl+BHWZmsIU9*X z@dNTTl%x*~VS%-IfR{u{3OI6ZxGY5#M^g96+{oPs-+a#|kbG%!fng4q_dW27$gf8W(%5(!sO6E`V&Y*E2C)cT*xZ82&U<~xLY&i^rc7t z(JxN^xpgG^^m^a}b^eYKUuy)9z$IfP#ELhvV`rJ05|U^q{YCjI?k9pNqS0n~vQVby z&_xY2f**Wtt%Edssn+=|-65LqV4A_hgp*rj-EWJc8;Lc1f!LIL-TJzzwlq;gpLs1- zeFa6qIz6H}vMD)ai^a7Rt8Se`Lv@I=svzUIT|Y>xW~H0l z?dL9@Xoh+C=s~$mySU~T%DTRnj9|Gk8Z;rbk~aht4a2XLS*EGald;{%{-ot&<{JsS ziKNm!fDsiV6d*mMRv%bU{fRuH=;Bg{n8+E1=|E};6k?U1tOoB)cqD5I{V+d#hfk)X z%J?j9KPFUTj0qXruXLHgkh|ePbu;+CHI3Tvcee-eSJu}$~IH@IW@l_yG9MM*?>A1 z%rrw@e#eyW<+saCS2SCo$L*B0+J|W^)@C8~LTHx7;@~ZiAijF10+?gjkR@8#5Md30 zp=4Q8SgR#)BX?Dqc@Jk&S@dWEqycoyQT84P$zxjW-qQ8eN`=VfE z<6SI>3`l*)qO{!%ACT*%6xU_gvYx)3lCPHf{vB3%uCdsRS$csdv<$=`L%3;BVY8fg z@gTrw|EgC{{9svk_%rIYyrE1a_=|`Y!W595kX^TvgchKlYs!RvdI8yHs8eoUM@1`& z4XUA?R4sm{7sqw48(z*+U=^6QlYYUHC#CF#l1kE3epv^8c*R*Va}9@1LhN`R(K&g| zhTmv}9(kyr0#6(e zouMr*Ky2F!atz{dxC}`qgp*+%PwmpRvGWcsk4?`TGtrBh(VQ-!v*8+$1e>nSRV4a8 z+$k(`MdAn-9i#&i;E&M=X(eF9kgCI*iG!2_@lGp{TSwS=SpW)~LZp7mSeXyLTh@=} zg6v;8)=-%k4s(N^KXJtZ$f&@%?*u*s7+9gs0+0QQcY4FfM>?Se3!12TLW9W7@Io*x zarjT1`4sK-2r2cD$7!uvcoQ`LYL+)@?nQAUVqzyuZYjBN=@?}P;s+4CPsTOZQ#LU8 z;Ij#a#<~S#LAYeH_7>lLcen8Q>qj$#4=VhCAN*~>;1OdCSVxk7fIth~bxtt)#HUMF8kFe;`x^UL5w8I#zUVeX+Y$Jnn%V&Bbh8t{;n`J(DG3 z$riZolp zS3P)BP5kR6VqOPOIW)Qp;zrWRt1Sh6WsW&y@Fb{hf;BX1B&02EL`X+g>zFzQcD1vG zE@h%2@^l@e!o98XgU*)|+6YNM>7r~>nKY}A1=Jk@8iXlykYRfh2nzFHk64&(T03UR z3Yo+ZHO=T`i@!K$dSLHF@OE1#P!UkkQqFJ2vBsh98ykzWiEA9j7cNU7o{;4)(0%w& zOMFM=pzcUfel5))&X&aTjJC4ZFFK^Si+Su5m>cWOG(DDw!t49(q78k=x*nub^;<&* zJ^nm5v6!$ErG+wu0Eno1A1@+{5@8sKK3@i{yyP#@vVs}9)g<%@qrTBDU-bf)l4ge%6NSfw!9ZHsV+eEMo3c05ONCfyP8Ob}yIF*%J7=fT6iL_!n0)S8D2IL9x$wS15EoQ$AjNjJ*oIk+0i!D16mUg&^<+VqUEixx#YWOrI(g==7d~n0k6qktJ=lfN7 z#DUi^cOWPE>ylbzwlgx%j>e#Eod{3Q2+tY|;;bEJ38Hu86D@?GMv@o1zx9-cBt3n( z<_2U9b>fsUIaqp1(m`L%K0Df8Er9?z??a>CI2NV%Y0zmMW60@(?+*5{RhuOJHUDHw zeJOR!4i9(Ows81-?{l%VujBZAem}<-wt3M}{(Wn1ag{bCIAGl}D3RMO$f+QOh<&dx zochR=nu$Q1YcGBtp1;;9ZFZHeGH-t;PxHx$=SCCe1Rn5Wt(ErrYe^=l2 z$tGfYj(NdHI!HjpbXbX&L|Dm?qV`LfZhe=!4S0C=Xa#{C1;1)ug#GFTtT^%K&M|3$f zVysHNehE{xRLD?%9{H1g&E;;1h_BTF0#)45G)m5)TJ|bNwWAz7Xh!40X^%WQwewag z7=YB$_s$4?bvid;LspMZ{Krj&KNCvrMcTDlc#MK=ec!D`0mDsCn-ZdhX$ZxI$08L- zS^-IlRfm!|3L|L1VsDNbC*l+J(WqGt%sozRJ>PbnJo$ATWDlDz-bhC+reLb0DQ=Q5 zEqateIYqFtV(G-k@VY*>snH7@#aFnqr7d6pyXgWK(<%S6s%65%rwUN)Xp_hdY_0tc z^xvSD)aCqWGmgRQUuMqH2-|f;HPZZqpR?2`Rfrce2?mg2^6Dwdv6!_#ZCQR+LKNZR^jOhm8Laza#EzoF4 zrL07GVyv-W9GKWat#8>g#+;FGRgRqI1GX~g6~kQ5?nMA-=S_Y5N9?)eZGY6wRMpl8 z8FRsvrY8BL2p=3M!p!{RUl+%ngm#;y{0Fq+8v{&Cav!+kE$9I`rxY@~lDH`b8CiDu zL$vsm9p+A@gTY$#wvswr4C>tRMMJc_6WiGQc-mHF)S-&4Ic5b>$r_r<-ENRFDjNkV z6DwLM&|Kxqo{_C$((MgwSiUkQ_9d>wZfTVQmv#!~0fV_nGlQdy;}ksdfRanxXdyrG zyW|)}NUUa*OkhlQi8!(xQvL?gb&uGk>?#t2K4gYruo|YuwUpCX6XCU6AODxUF$Qld zEv*j*3KHh7CPn1U#$T7sX_25|y)tt3#mwh>bb_)e%Xx?8^j9&oi$10g^YQ#xqwFb( zNl#4`{&lEs|53yDNsvKY9XQ#-GAwKkOis{jGt6rZ@SSS#ze|D$8_eIW#Wqm7JrWdW zXp$wTwiBp(WIDG|hFXdVm-}B~r)Hxt^vg++xPlzTfzlX{x*_BZNVu?Q-(wOsk>n$C#4`%s z=0g%AuA&5mi}GH&M^Re1b@Z&FBcAxSde8%JqEs6h^yymJ@=>=ew3LIsL6~1kt*$`n zYbY!(yfL4sCE52dbS&B0*I<0ZaD-6*AS()c5sl1;jji`MZcUOgLaXpHdB|#3NrDx z?l|7-j(W&u?W0bI%4_Blu6v}AfcVkJDgIhYbJ2(+Tz0xp z(=>s zN8#*Unq9kMw$rvdlNy&L5Gy0AAq}vr5|;_c4}=DK)%Ro#cmv^4_m@ zQWW{RcGW;OrGbRhsRBAOqbtc#Qc`j#mCn+FM8VcIxG^?^-A~f`_%cS?b6b$gL{lLJ zmr%X3VcHCnmvvn$C}hFx({8}3)2<0i-XgT?G+pO4#wz>u-L{&Q{r7BW#OjthHrGlIeHq~?>D((H;=!8L|@ovPed$K#2O1VUiNT@N*4f%I%RQH6s zAE2?Ofw(i=P+Z^|RVTWQZ@G}hsP13RlW5yp$!}u*4g=*4`6c7ts7JmlHxY4?eQNSA zUEvcDy#7ZC&nbyk2t50B67H}@y36VFmaaiVqnrRkvO1VzMT|+CsURY!cq@ zC{giSIa~m9j<=SL{e## zu@ghuW4L5cEiK3haU3Y=*y)C34-!CFOlwYhUCi5h+FmTmK@aw;c@S%mYd^%dh)1}> zWYj0PQk2t2YNxZfsw(G_5d&@=IF+H2r&1^GydNr9?gchQnv?rL<*bF1%#qfHBCGrX zQOpJQiKkt49PK3yrH-?&RAhLcfyQ%IaqKY&@X#xtKY{yYpM^_)uEk)R(Jn6o!^Do* zBVN`tH=}k4(iNvh+D*paxD31PA-r|HCIY|M!>A@^P5Cx}PRTbkVyVh3Zy<1PJRt28hnd}S5GXZ-VVT%ZWfM9M3l~7+sK5CLr*sx&5&~T z>N9d2ud>QmfM7gj9+<1OK5oJzkNAAQzp~a_;vvgjnMlUP%9&Q0rsj|v!Ea~gIy+W6 zzBv1IB(z_x?h%n4ae+ewhn7z^TQ}@T6{W5~N-P)|^^f6VxtYktzycWAPwg94e z_>i1Wc>fgkmfm0nFP>aY8oi@Sl|si>I0D-n&xve{Xbk#* z5y7}*_g~Ika>cc%h|LMiR=P<-&&>mz2ZVZ+(z~cEreoz1aCyg4!`D}a&730hOLPOT zJ`#F+3y@mPIIb7x8mYhXE}m~d?OqR!`(Zo9Hw0mYhL?eZ+#_Zn%8E6XVR#;~;vc=J z_$I0t8AMabwXNdPS{Jln&WS4wck^4P{r4wW=&*}FiuVosyl%45MfmN{&kZ*18L5SD z^VPDO!;!68z-@Ah-=_b-m?%ZAV6W4M6k>y*{Dy5n8a~RkXd*GgpTc*4)IA@lj7IB3 zBg|9E^lh>^ef#(X&E`V%*OBj0?|t!{c83&eSo6xI$7Xtb|I9aE5@kF3ovDPT%5>@*XCQBzZQyF9nB70i(&J{`S8vbx&4fB_hF|!@g zoBn>j3*B-1vMdXF`OAa?1XboBPsLCIxF&@!a&g9;U-x_=yfYTurX~F{Y&Rf80veA} zaYIuTtFC8G7gSO<;Q(q@=HZ{=ONE@dUc7ZPz+GLmexSRMj_$>T(0AH*t4CmnBn9EZ zpDBhP!_6ZY7a<|MeW$v<-SG`HL>_t5MpltwoX=5UUU@7@+1g{Cw>TBTzodY3NuP+L z{>~2$#eui8BYe@o+nu4fvq#f9-M@?Vtzn(?Wafz@RzoNfE3@E!@=En7=lk!PR#5zQ} zcJR@gYjME{1+W@Wa-#yUE6{?GNfcL$Y*O24N++e%p2MQS5fB4;DBnoR=JW_O+)gsr z6{nC?j`JOBE)4llMbtY(x>m4`r)>|2?5QDy} zde*ERVPS~~Xul*q&BOA?)maTeDjZCpCL_y8Xhx~Og31&rq8OQIWI8BfVK-ZDWEnC) z1cyCuWOu@S0$r+x8HzZ!p%(DkV61Xl6o|-ghdjSW*t|E2i&yzR5YmqFJzll+6Qq0^ z0mZ%-I>+RxrAr>bbN=YZUSfs6Z!)xeG`O+uq*tK16u!3$EPmvXl{o1rsR zh$45u*gD{SmQZvc`b(XkLlRl8#TiSLPLz1dm&<%qzgKhnsWKYZNCk@ZUP=tcs_gHdh?a^e)f*; z=Wh@-s3$BaBWuHdA$f19cMR|y_M;J@RXe}S4J_BL;eLjR}je|h<%{kuB~Z@%sItQ~CMy&1Uutro*~ zw0AQ4TM~G)5;Sr)Gcb}D;eS{DtF!V(4mM8q21X9}?}3y5N8ec9in>p?tQj|&0FaR#W zMG06u%$0h+@-*4L?qbnevQXK+uxwGW{(3q`7=eH^_6Dq4o%RjaP zR;K*xs8;(xXlNk_!J1mT#r87!_;)*W&M&=-xrV@9i?C*wWNz>UI>H?@S_=1K`VO%r zxO0xpr3)li-CGRqfF}#vZ?ZM6LPmqgGp+$3-)+~GwZCr=^sOv?z9-xC`n?rQnOuWq zhRQ_a`O3D(j3O{ocqrqHbLv&^}HtQKAgrT+H8ccZT zdH--Gnci3hL!8WMqC(W%PX;1}fC}bZa>hMvR2r8se;9Tce!|M2W>)d0-mCfRn z6QZ7SOh#T|d(Hp`lLF zDVZmjB)*7CQOrDAtGrHdLVZ1!^|Lfeb{Mm-i`Xakhx|?$1hJbh&H4jS*L=mHTZsMY z{OnNxp^*qPq5b$W?c1Rp(ENzKM~MNk3rPw*=*&2!ubERLat zWVgT$JR_MG>h}Bct^45qf_Az)Bt*vf0Eu&0)KETut@7+TSx>$IHyj8NHAH;hja)Fe zH8PyR%qT;a+UYy$JDc$&k5tWim$HG7@rPEIem zZ1rMhMXWL@i?O*ibDjE5b}uPiK@D>c%J}Fj5mik)Rw_A|I5q?#V-WRjRD z(8~=SX}YLhIF93J8TR^&EY`0#eKiI)UY_Xb9uk#vw6)l7M(0)Y9ZyFdbQC&N>x%0s zFjRQS`9NKm*QVXUL$@*A5kQL_jQ{x_SHyvE#O!8eW7-Y2UMY-a z_^lcIHa)bvwg_NLD35Afj~MrmvoD2YFwsB`c43vJ@Fz-=JGa*;2z*x-;_B zH*!gAYlzT;O;;Lwt@(H}tr?tl02QDN8AqN$YEM4ki@+52m-u{6&iLl&iR^nj#-`Mk zr5;NGNLoWE(~l~}8iC5&;(Lf!pIiMZH|7thA`^9k61T+75{(i)l*URI33Nf(K_F^g zqI|gRdyG?_@^zh)jL|(k<7zx%JM0mPwtgsUcB+TMsVEszf94t5jlck0MiT8b8X3KW zj+q!4{Z6UB^Ahuje zb{NAz#3AH;QxBA&DEb}@KR@#dVisEEXC4e|c!lzo4g*xy=P9gA{8?FY`@T&p4E`X3 z-b6$vX;u@~w?*nXjQRa-yX=}-yY&MrciEVa|96Hav%jy!uJe`j(^TMzy`76>1RSf5I7sQj5K_{ZX zPYym{*AVfvNJ!r(g!^YV33HcDN0bdQC>qgjSX2P41sWnD*{C<#Q!hVAZ(i(Hhjn1A zCIfY26m^G&In|Lr3}WPno6C{Zv!6tmDKsn;}Lxc6MY&l-8uRPihqH4X! z&8&BVg>|aukSNhy?rgHnx=KG@NOajp!6omf;#hzA+0XO3%Tg5E(%{oppsVAgB^v0I zvo_@4xH`;L!W(vMfwub%lEZ~;FBc+9#^(2JCY^A>JyPeAQ zL{$Qv?M)&TtNMfe%tA7qPP6^P)dI9_Qtn3em4=Yqgt;mQWo=yLQ?4?{4c;y zPX2#!%lZNt=rv$cd5baW(76cl7%Urgh`~zj z^$?!*G6D`4KMXL|Fl}0sSK_Fo2~%spq8t3ey>Pa z(YRB8h+Fq8Z^m-Z+w80A9#$L+U$nqFAjl>-Kx88tun7^4LsStu-#NB=J@svAddyeX z@clkACaAC@6V4L2LI{Of%(0dlS56~TS}T-9C|kH+GjqL)CO_%Im4y9)sf&?~PrO95 z1bZg6M;Z38FH@52P`}j_ZDIGo-np4Qk3ElLwa{~lg+t%VRv^7FF^F}kH^}6Hs4@rI z(1JZa!luD;rA^;WYtYN+_hLg0QRi{9PiiJs5+~_6m8;PmUT5Fz@0QZit?#G541Yru zj220=FWr?HD9l@uXZ#kkzcOjO$%1wNDF>5AI|yjYgSu=rYX_9#zQ{W*s%oFx=ApyT zVwHh^k|Fwm3H}19oH2|3{IRq z*{`nwJ8zhSI^15uuGpMkPuBqsSr&_$wF3GGzlTkL$8e4QD)*auBV-~`oz6#dA#G^3 zfR~7qIo5O$W=43TGrgoD)-WMjT_S^)UKTT~DTI2>>@g>y3PYQpz(n_4#O!MmAB1Vs z0_^h|TjBW3rR-=1rO;v_ovz%NaB#8D(G0=)*??YH^|C2;;afRYr(h)}r3YbB*FwHq z$;JV+l_9P!H0{bzvuHhQ_DT&0&BkR(kI0wCRHfeR&wHvMV(aJ!||O1$#u=>6=TL6UCzQ^c<+Y^BB*io^|4m7MLjhW zWlZZ9TA)3nLunw_5N4XSmb@H>h;p(i7&l3oN@K~KTrBXvbc9CJEf6=5fne$lTA89_ z!!+>KMFe24@q~op@bF5YP^JurgY#Y{)_^v4$%j%<`D4g^odmsVm1{T=Zq>T&s~i25 zP@t_N0i{1?Cd^XP+MIwMbNl?X50PCXbKCNUTz;q26?zsgm}-nCoZD5ExsW-flyB*z zbT29%le_<=Fy5M`Y~!-c0f!6L$j@>y6(PB6e)>q)`D?4J{umq(|AIu%h`BK2R;kQ& zZt28&+qOs<+Lp!Acy!w~KB0=UO)U0hFjc5%)~-=u2hv~$bi8CMNoa{A@v=@8u~6{5 zI#k%Vmeola8sysym|2LSA4^H)r~IWfPZ}2H}G!VJ)+Omo6UVT~7Uo z5x3Xu?c3b7;ns{c4M&dnUxt7sHu*8Ty(ZG|+6SEw=!d{9ev?oOVCT{jG46JrYP<9l zMsaK!D{8#C4K~vwJ6+fhL=54Axq?H^7*(4m58)>J5G%B2 z?#iJI)-6KkI>vzo-}u0h6N!VSkM6H8G~iHDJBo=XllVcA0#EXo6hrvs{98Xm%*n0g zI9!Oe5OWi?khXBQ0M&;Y!FZRbofZt6Vd>&)emB{ym$_Ej>n`RGBd1tB-v!N+4GlAJ zDzPcgf9pInxgWP!RK0)h!j$9=U=8D!8NO^=1E$Hl$Bn(Oy(~m_*2^$%xn>H8yl?glU0* z!zdYh*Cz*~{yY2=I_>Xyew>O;EvAHzZCDZGc-I0~dK8)5Rt~))FD7$y?!UTA>%c5) zu2#8bWg6gTI>Q)h!<@2-s_!Jaim*+?P6*vC65ECRcF_?hK9w($22rO@e6C7^oeQW9PiXnCQ zwU7;b*x|9-O!T>quU|jg#VG%LQT4K)1=X8`HeZz4bClMS-Jc!*K{Hh`bLg72Fjx)v zk>m0?lsr%>hr)gFIYlH_(u|>H)G_%1jeN2VAu&bed)a|`W!!?_)E8*Jz2BQ1;5sVv z&$?u_;B8jA=QGJ=k4xIS=aPX*&=(_wYEdOg(7LXMMs6Xbk-r3IJ}hnpZ`_(NqKJbx zP;zCwEPoSkJ_L310pE2p&P=KbiH&4mM51NHqGs<))003y!g`C5-mQ%-J6I3#b5xw`EQw6Q$&vMgQr-RV_C|sLCvfM2hi-#VnDr{m_!<{YhL6 z$Fb+A>V&jN8qpi8gyzzw=Zo?H`cEgNKW%eD6AHv7}6DHVJ)kT%< ziEv^Lt^=W-QveEIG34pY-%jlLRT|^$U+=dahjf|yii`xV{MNGR4&VnMKpg^`gPI4L z3!B%PJDOjcFWr12*CIP2Un4KEK&945b>cfIIw)+rJi1)EyzyUoyNHVB^Kl#SoUxs8 zopGYrvA-^RE!!=-EjzIvIo7F->%v$j*cj9Q@Nr=CO6Zo|wsE7-XmmO?xOTQ#lz+Lj zsEwi{H)5&M{O*2Mh||jK`LNbgIlZ7~5-V3b~&HEniC&^goxxtcebvGrS+C*>NYIDIkmPb5>-D5 zjmQ1GQz#{qnlw`=`)KjI<`CuEQ#_j2?|oa;DDRrfp{zTVFf3=z4yzZU4Mvv_)^+~4 zFg4lAY_SwSpxrgeKR>p9*A*Nc_vs!D6bjf>gHw|fWV(n@I^yC1C-s(0PHUHM1l#xh z2m``e*3I0f5`?>ToFa41_%KGyur#|IX`VZs|B6^|nWvQQL6zXUBa!g2^-`1*tQLFQ zFz9n{YHW|zm?0m;WdO}1-={FyAnHA?K!U-?>&m+--GI6=^Ttj@3h>RdV^K|BS*m^Vl0gtTA_RxZa}~A3txIXz;Qn>)EMsFt1Ry-~q z=9@KYOlG6+sA`dr64t6ZrweLq?;fBQLbfcA+|Ti>XA z3CqOs?PHsl&9m!I*C*FtGg}BA%MB6w-{XBPDC?S1!6dN?$nZ3;@%Estl{W@*Yz3MD z7&_+emY1mtBPFM;&Q|LkWvOoN%ZirIM@S9i=?P!O3+!z)SL$=^GlmBI<-G1`(RYnI zj%RprA8^@*hsELiaBfUwQPb_6=*=z>dP)~laQ z#MS!*)uXOZyB}uJKubRu)R`aUOB@0ZPY=t*otU(le<=*AK2NgAtS1U(x`E9{9^ail zQXI{?{N7(DBYA1g-_U72@=p8UyHE=7jY0fA=Eyoo0GX8{2ixLnvK4Gg;`QlVcTJ!6 zG0ro!L#_ku<&wp^xbrfX*^`de=l&Rc?dQhViSc`$uf@N-5YM}5TD5d|gO=h(keb#(q1l;NhM`pF#d0FJ`u)b?lOk)13^EMetmJ zG@%k?fgm65eeUb&Upfc?UCdkY%w=z zrLn0_T@D;Wn$hr*=v;?8e5BCeG~aDnmXOrcnf4820_4<_q9c-4S9IRXFQPiQxt!fJ z2iNKk$HEi|4>v>=nnLiVL}ioWhax_fR>fa#I3e55qd5nMe`(9}hS}ko4sG}wGH|Z? zh%bSV^HuDaKMFbOrkJC9^XfPzx%r40`~Sj^$cwT=edGxl@N4VGKQt_kpyX-xTMj`a zjwr0kD`_d=)n;8~XM$c+j;C%5zRm`ziyHI8E$ppjXTtFY$@A&O-6n^o!i9RscP6zH z*K2*g=3M7+cs9$yTYV(Xtil;LWvRGyT4@3hMMvg1ESp(!i0%Ab>g&?+KHAlp<)1a) zd01X=XLQM4;DEN@TpHtSQas-PaCS1XNR6Tp?sda0Y*qzLA`d=8TwSHDYi7hV*jzln z6uso-%RF>;RU8qqZmE7gM5V$Fpw2qyIn|*2%6{=Y4+sdv$fKbxg1XK zh8zWGQA?Mz>JuIU+_2fqI@O&Y^>))Buvr>P*Mhtm=f5f(b0&#lk=z+O=5sHMX~QMX zeAu~Rx1ZsaMjyIVIlXaBX)`Co(YvaGVk?YM`Eq3uh^;gfd)Impdz}A5Yp+TWJRh6< zX?9xAP{33u$T9val6djPZ@;6>prn3Wj%96R&l=M9OFuw^W2AwvC)x zCr7~HvUo5dSd-T+=(0Ez1xG#V-mod=%bX%A03m$!9r9AJJZ0Jq|*0D9v zGI@!;sT51dyfHJo9%`up_|2Mwuc5-E(ON8D<9b**brrO5G!+JRBL(!jzN(@Q7j-2jD0|Q zVCcluH)7=bjNAjUNq*p->aZv@h8|%rN_Mz&T1AYW8Z-NtbWIP)cunkuEEg;^G4}UK zS+3!+klfT}Vxm<(##~WhvUFUsDMtt~oyo#OXft(W$bw}vQ%o^168TeNC1VBNn*{zi zPU3jdXfgrbYJyba&qdk%al)~J3SXqg$RIybNu`K^T&Q)4gD_xCB6re-Jzj@STI6() zZyg)9)2A%(w;eCB`t}W%lkw^cA;t5qOsD?fiWd~9nocDaC!j_DQGie+i0#bblz`Ye zE_~{&pRAT1shX2{`{{evjrteDZm~38!5AG&LSt7GLEA$(8Qu7m1xM9@bbyGu;`y#jXY}4WqP)2Vs9B1GQ z*Cdb%u#v|YKJ`7PfLm zJ40@g5FWRe)&3SPS^IR{_Ue{a%ve}%)5s@D1PIWK(~R!WKL7OCen3V~K@>zB zlVp%+d{?{I5LxKbaJN zI!R&9vAJ-~ohK{BVZ)_1p zzW{wl-D_}_K@v#$S6Gqb)EZvMId<^8WCkp4Xf$oA&{ zjm-jNek07hpV{6!`krEDVtVVt`~8RcZ5q6%*qB)VXqo95{$)Xr&&0xt4+JpcGtmS7 z1JWXKQL^ z;6Q6*Z$k0cg&^?uvH?dk8*9OL$`QFBI|Drf6Fmz(0Kf!bU}RCJ_@}@BLsj~lJtUy_ zc2(Gz{Au%_*8ht;^nVyfuJ%U8P;aCmCMf!UZ}`kWAR`dp82>LFfSHNmZDij!eCvPd z=vf&5FFFRm+ZV-u+R@Vk-p2Vqbu0k3|JDJ3Y{384(X#-V{@X7A_-6TE^}Xp>-$MN# zJ^(=Gw>k4aI@Y%k{ihu>6WjlkeJlDlPyV)ZwAXvP?(P4)BSOK<-RM15?{_$P8=JTB n{)^@GE{j_>5)vU9VW|HPZvI;z literal 0 HcmV?d00001 diff --git a/examples/dev/backward_trc_final b/examples/dev/backward_trc_final new file mode 100644 index 0000000000..a6d47c2df1 --- /dev/null +++ b/examples/dev/backward_trc_final @@ -0,0 +1,63 @@ +digraph { + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "2([Symbol name=unpack_sequence])" + "0([Symbol name=unpack_trivial])" -> "2([Symbol name=unpack_sequence])" + "3([Symbol name=unpack_sequence])" + "1([Symbol name=unpack_trivial])" -> "3([Symbol name=unpack_sequence])" + "4([Symbol name=unpack_sequence])" + "2([Symbol name=unpack_sequence])" -> "4([Symbol name=unpack_sequence])" + "5([Symbol name=reshape])" + "3([Symbol name=unpack_sequence])" -> "5([Symbol name=reshape])" + "7([Symbol name=reshape])" + "4([Symbol name=unpack_sequence])" -> "7([Symbol name=reshape])" + "8([Symbol name=reshape])" + "4([Symbol name=unpack_sequence])" -> "8([Symbol name=reshape])" + "9([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "9([Symbol name=matmul])" + "5([Symbol name=reshape])" -> "9([Symbol name=matmul])" + "12([Symbol name=nvFusion0])" + "11([Symbol name=reshape])" -> "12([Symbol name=nvFusion0])" + "4([Symbol name=unpack_sequence])" -> "12([Symbol name=nvFusion0])" + "19([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "19([Symbol name=matmul])" + "15([Symbol name=reshape])" -> "19([Symbol name=matmul])" + "20([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "20([Symbol name=matmul])" + "13([Symbol name=reshape])" -> "20([Symbol name=matmul])" + "6([Symbol name=permute])" + "5([Symbol name=reshape])" -> "6([Symbol name=permute])" + "10([Symbol name=matmul])" + "6([Symbol name=permute])" -> "10([Symbol name=matmul])" + "7([Symbol name=reshape])" -> "10([Symbol name=matmul])" + "17([Symbol name=matmul])" + "16([Symbol name=permute])" -> "17([Symbol name=matmul])" + "8([Symbol name=reshape])" -> "17([Symbol name=matmul])" + "18([Symbol name=matmul])" + "8([Symbol name=reshape])" -> "18([Symbol name=matmul])" + "14([Symbol name=permute])" -> "18([Symbol name=matmul])" + "11([Symbol name=reshape])" + "9([Symbol name=matmul])" -> "11([Symbol name=reshape])" + "13([Symbol name=reshape])" + "12([Symbol name=nvFusion0])" -> "13([Symbol name=reshape])" + "15([Symbol name=reshape])" + "12([Symbol name=nvFusion0])" -> "15([Symbol name=reshape])" + "22([Symbol name=reshape])" + "19([Symbol name=matmul])" -> "22([Symbol name=reshape])" + "21([Symbol name=reshape])" + "20([Symbol name=matmul])" -> "21([Symbol name=reshape])" + "24([Symbol name=return])" + "17([Symbol name=matmul])" -> "24([Symbol name=return])" + "18([Symbol name=matmul])" -> "24([Symbol name=return])" + "10([Symbol name=matmul])" -> "24([Symbol name=return])" + "23([Symbol name=nvFusion1])" -> "24([Symbol name=return])" + "14([Symbol name=permute])" + "13([Symbol name=reshape])" -> "14([Symbol name=permute])" + "16([Symbol name=permute])" + "15([Symbol name=reshape])" -> "16([Symbol name=permute])" + "23([Symbol name=nvFusion1])" + "21([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" + "22([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" +} diff --git a/examples/dev/backward_trc_final.pdf b/examples/dev/backward_trc_final.pdf new file mode 100644 index 0000000000000000000000000000000000000000..3932a6093d9be465465aa05099c82817366d1a44 GIT binary patch literal 14194 zcma)j1ymeM^DmmOZ;^liR)CeBDUyHyfJMg8(#X~r z!1=_IM*;u9 ze#W^HF2;pQ1&v}5>1bbzDHHRHayM;GKiuceZKQz0sbEGj3yqJov%DnJ@|XO^YJq}S z?GoXgjnS|4waKHa`n!4pliNWuhxBHyMk^2XgQFGWTaw2I{p*wcRJGxwpIdgyUDi4^U;?>%iNq&x>93irr+2D9X<}e*!lOUayU49c2e>7!&;YcD*R}^Gd z4_|d2Fxq?gq5btDYtpPh-}=#)i_>Md#zv2txpE?XO#}-fe>JAYPxCWcm1c$0tCK~& zn^jgvL$9=3UCxJ3HkV0`P9BQx+jH;ZUlL#HkQ@~+K5koYcRXOGPd-=$@V7qBw%^=$ z9$DjGTIrBU-|h|nJZ^3Bc)TdTx<_|CIw$`b{_^VK)RX1Gx!mxP<7VL2A)vAR@d4lU z)@EzNQ%-YHmgo+k2zN_TPG zsBh%~5||~2orF6XTj;gX!T?ZZKUW*{x_3AB4uX>Jl!kJ(wMk!dwpp>Pqe*t@p_86I z7QfQUdvzb~x^|x|I@j9GA-B-7_5Mw@^rEGzp!@CmjQlb2VYN=7j#j~BHc?|-)Tcbt zBc2`2(+cN_s+xr}Pn!5oY0h(Qb*^pa?iKKZ)fUH==cs9-M8pE!%$jffLCfgDiAzbDP`6#Vx?TQ?eC!af`RdBKcBn z%EFRzsAANj@5mtQ6Ds>|=K-;Oy$%0mDZYB|Zu1=^q_A-$%e{xf>$}z5!9(4*`MT6` zZWpWFC^BomWKcRDkgK5sN=12!9pFy~sMr`(pA3^pQX)sp>ZE zDy56q)Jf!3`*vDa_}b{2#06O(StoA9fGUI0%WVI7^4R$Xi-^4+N|N&l6|!vfx#rbQ zoc$K1S;BZ839uwfS{YTpaEjM2w680}P3xaV@!}Kc!U{^+6w4Dg{HT{3Dz;mHN24h&htipW z;*g>+bXJ0HE}w1>6bww{5g$=yhHgS12T*C}sO!T|=dX4t%Y!Vv`QP>h_7>Uny%xd} zQ&{w9JYa?rNM_z*l`*=LP4agkVfX(^$j0E+K)8@KWv=IRei68#!pPzK*yv1eIseYq z@k$s%#pot#_;&fT(m1Q9C;(0E!=IewpU0Wp+ zUAg8-+yKP3`fN{v(kybsCVxU=1r!Hn@SjdG9*jv$Yi-M|pb=QT(AE1U!@j8i<*&yF ze{#N;N`>;Gua6Yav9%3RgY|_{B+X!8W60$(6zUHVUMph*C?QZWn5sxHY=uyg$G;#u zO~N+%Ui=HI!?n1G#U0;G+YQ$jG@`uDh_AnSAb3o_fsFvAiD*xc!0!Oxft3Pv`BAF# zpb6_&``Z&FYB9A;VX5|0d|T<0w|H~*j`kXe1sbV}{w(xHXT*~fm}RXoI$`8?vrc_= zos8yo;byLiI#Al-p7^hkC_i|dEncGILxairAsKLeMbSL#D8?BGluo@%ft)sP0a)ET zydGCVMjr{Qp>l6wG_Q&%X|WR2vp|xhnMg16^p@p_iU8d*sS{l^(ZWbobGJ z#s;?p_$;+@9$d4NHF&1JhG2QU6Y*(>RE+Fq!sN(PZdIR>mu{>E*PwD8xca?sfQpx1 z7DIvIFJ{YHM6y*9I5M*kkQb(Sakuqe3v&UE(a7@hi^FH8nVL4vVmwzkW2r2>+8dl| z3?uU|Mydf32ka?Rmclv2f*ZDPj?R{|dS*mnp;0$DEZj2 z=rGu@av!Cc2;eUECuLUl(WEo+CYjF^9a+zzO)x`6CGRx7A}san-cgjDU8aGUgk(Dn zuQw6ewRCCd(i(zJ>2(7>;FG{6%rKEU%r{JPtodvZ?%cwC7TxMsp5uHE;`W$+!+78? zhQ~fxtpBJWin=K3%j1A(??p97hAabJyEz&6vid682%*vD4Fes6kba&mg*TJdFi{hw z9kxB9>XcUvBvK?xtP+VFMp8FR8)^BSOqiIS^`tY5<2I+=TfLLE+(JbE16S1L?KA>o zKPxc=V|l9tt@2ns;O^Hc#-3q$7Je8$U_pGjzt9JU&%pdSVsN~9r&b|e$$pw!*yYI} zTj3F3_>Fx{d}yQCVGh2m%bTWg+f0kd^qfAoX7k{=#KVw;ArdG#I{-~p8VOC`#vDWF zEP3?TJ#GssHU5weOr^2n7n!IcKk>Z6Z=m|rtpBqtn`^B!*^CMvmeV{q+ zoUgjCagziQU;zD7g-aJjwq>gEh#)jE?5jR%N$o&z!q?eCR+3^Jm@!6L;`f82eOL)p z>NUarL|+;;%p$bIlQ&~%{Tca4Q`Mp4TP!hch-FUs2@g1;d?!{#oa}I!uNB(5IXFD2 z!g;0QdPA_@Nrwgb#inIT(elz%)JHoty{aaSH6N##&5?JOb*0{tXg8%7lro8J$EYID zu{9vQ;{q*Io|eS8Zw*u$qjG5b&;MF$K6I=oHS|Oxpf3V;9lOIBE502Ul<=9Qkm0^v zLw69#a%cvAKbt*iWXnC`4ZkK78#lJ<`EX5DFO{^@jUqC$hn2x>$Z%kbkE=Eh8gOjM zwR=4#;yY5)HQPOrU!=tD=ZiwFtY*AgE$K_O6^ibPV9n%9rM*Seo8p259w0Igms#3d z?RFXBg^&ixl*S%NP&1c?0bkf)O5#YyuJRE9*-@){<5x#^IKJ5%%Jo%Weyc<+bonUi zj;8m0*sH$cF!cqU4m5`#B#)MuJ74Bdl-sJ#DmPf@V!g>b6D&~7ljE)iX zUBhNN2ga8o9+Ok0IWn*=?n}{DeNOx4q@2etnR13JA4%o2!*gj?Gwc;M{PgiDSA`|x zyhbF;@%0MtP_*umFf;zbTih}+91GY(7s zmyJ}M(uxSeWl`lgCT}Uo_mr^kDBS!-zu8Co8h#?}J*M~cUV;G)Eg>=qmTVT~w1pVy z7nfvTPlaO$BWz%zzYliavq}tGH%rGQg&Z*i69=+1m}r5wi!-!?!g|=g!641uiETD} zAw^HG3Ckni+rz{rl{r_IMC}d5ZmK{BA{AN!TuqJR3^zH#bbyVd?Fjg|jYKJe6HpfEDC)8)zhhk|cjev-DQY5HAyI z6@fhvKZvQ0G4}QHsVmfdlw4>^U*H=NTOg$Govkh^@Ll@#5=(z$=+w9I9r|2$$xx;6 z&?(1yUP&fj+TvW))R}M71g7;1)piU`GunBaenK%@oT#USoAc!GWWJJj@663X zUn{YREgCQ;0z+$MHkI^{r=?R>)7X{gjW-RKL+3;g4-=J2u{Se!WqUlaPd2AQ5NpXW@DsRBP!icPm^Q$pS}V@Hdz2}nw0A{We_0nfgN;>)hf zSR~I|t-m+VMS4n4Hw zt+wlr&1^LN%=PG9tx14*GXsXyyif?zJ$-DvYGE&FXQ)0XN!h$M*P5`J`f?Z%r5EZ3 zYtzJAftlLfBz^vVO?Rj6{ivpzQ||^xmaclL*~}ilfoX7*^6A@QtSNFOsUjzrVH#nF zi0UMm^R*grfnFs;w35?^2!55^L1Xg66XmSgA}dk^hUr?3&`*Mhha~W{YR1kS?;CN} z0@1Ra8!ESUV#Q5J&L?m@kY%(@fe__40*$FTdn+YVmJTQ@2D$Fhn8xApPan7JBR=%& z3;7Vl+I0F&$Tc|7h2ES|iZOSkm2zxFTjav^hCx=_imzsAva}LRpcu^cWHOy_ zfT3_0fdG?{VlC*`cU&a)I*yC{ZYfLAo}iH26zUIAZuDc;pHVxSn^@qeWNnT~mvS0Y zA>uFJkS>Fp7ZGoMoWq8F0r@-dhf42;9K<#%{j^VOpaB0$s|}K;iYigee8?d~E+N(I z(wCdT%o$T3EUoR2<1<5u$I`yjM3Z93hj3)BT^BOszJgLdpvoOVsho0+I!?R(Y<4<| z^@Wa}7bD_LDxHWH{pn=|o>vKA5DMw*$7;MF=kXPFX41G)^hZhXY#G~$H<(JPHYr34z?NEG$T94~hJLyyvL}44(|FJR!OEeLt9i+M=#jIz1YcmYzZjRC0&Cx**=*4L3MI=K@gTJCDPye3dF3-W1=jT(n z1qfv2coujjIsOxt`5lN6bF`IEuzd=`00ab{8BYfbi1mp8u!sr+*Z?g0pr`v&Xy&Ql z&z(iw%F^~9(qJZ5CRPB~-=Uo+p{L(p%Gm#-48Wr3WNipwkp&qU0;mCOzm>~^AcjwE zvHe|_<%6Mt2}s1s5uowJVg<1A@^AonIJf{>zr{tL6hFz^0G>ljf8t0I5GyHgK^bGEQi^p#0n3Va2v%;=*Ko zM?k=&+@tHvztS}Tr(lvf>ictSITFy&j}-2~joE-(PM186?prjvKQ>=#oBdq5DUtAD z6_I>99hlMfT@0PM0mqb`a^nl2YbdEvPc4J%_L|xEE*Ybvaf>Ojb*E_0!%Tv14!vc8 zuhn+CsW(pXz4Zf($bBI`j-YCSYMEugcB{9G4}Ke}M%$T?rpJp}DP%mRqO0?lYw-xw zmDWqD>Rgw?UB1Mt-v~>S@dZ^Ye6>DRVIt96zfwJs9go^9s{gtA4T}5*N^eY{CM{Ls z(x>{xhSJbMa-a+LDKH%*2fpinq}SQi@Jb% z#U1_i(9wByXWYaH;un=!6dWoB8ph1W7$du^kj#EcJ3CtwH?5<+2)?SQEG{)Yk)V** z^WTFc<7Dj>ZwJknCdbMp%vK6Kx=$RmROQRUIu(u;1yDc9X-wfy;ZMgk)I+oM7}g7G zWy}=tPr+vpihn99W8+MLRg-Fo^XWWj~twd>|H0FUr6M4-lHS|13zI%LPNi)>_F95BK2{idb|@H$o;>@VG#G;G-^))Z3xsdN~y z5`hIp&FdY835o>@wVNGy$s3mSS@r`*-;1Rbj0mXq>#P@C5)vI1rF(gF*x)CsCi1db zq{=?#T|3Huu~^8)yRE*Zso}|8;$ByZ@yEg7M$K6=CT2h<g7IflzI&FthMm3uVD|?HC-M7uR6?PfA)CDAQAw52gCaE*C3#-G zzCM*?X6@j;ecSbKNrR<^o%v)m12W&cbjOKS!hdDb$W+^4Q%8{0k1_fdxrKm9BHa) z)Q25mzB)J2_z>!2eG?g_j?fP4hG2pBA;l_9JB`LNTGvw{-)!{`ubE((Q&K~&#OG{ z)rDA7&Fae6HkN15%nn@<6=h`PtxWX0QeYt_`s?26=}W0Key2_e(^^(c@=jQD(gBfC z`Yc%OM!(GoOt!1+b1Kcmg#?l_9+@dwUQcu6h7h_D<`5vQIE-r-$<~o|cnC5-aD{#a(kH65{?;U8QjS6S-)%g5)<<9ze z99=ijS?VOzF|1X!mkcaOMbER&2q`gF#vymR$$Hs#*=>0ek8^8XE~9dg-iY6acm5;; z2i50s$&c&PO_4cdga!Xi51*rd z)STB{@w`JV+eH5$L`LJv^j8$akNMTMS0!ztXgKd2==+5k_z|V&x?SwmK2y_6#LmFD zCo?UZ4ODPo=ArBiOT$5e#aSdcFNP25i?MyOF-0wB#%su-Mle3et3^+3*|IJadSA`0 z1{0>_%JyP5Sszy^9nO~4s0i|_E`{PBeOFdcQK_;pVp9rs!NVO<_OQACc5`vi0o(29 zw7BTea27D?ZbM|!0>8oJe?EOD!guAT&Y@zqn0pOh#xc5sJ85|h|Dm)v9_F**XOxjo zExfgz@DJoh5nE&_XkD-8YFsW$wHk@--Amkq)=M?pKkuomw=}q355x|}w1Hm)*Ej{)I~cqw;at_w#js^X zDDQ2HaJ5nT^&l#!64sW2gKu_oxt4GUC+KDz#NMX+*-_1VX&+u*ig?{Z*_+YN0xL>l z;Z}@?O%GkezXFp5UmTN$@hg5cTbrrzAzZQjzB!ttVBD_GzSXa24G8rPgu0D1n(@{? z{~jOkagzK(#$?dzo?%h3ywF+HhOHrQV|t>y^;$B3*tz+rs|?zJXuEsHjB$rN<6#Fx z%^_qQEE+|q(sQ_BcG!OSHOY4&#^M6TCMPy1i!LG3ILJ{9hfB^XlaWcuL5yfluM{C> zx2f3P{V`|i+AS{y+4qHOE|rvtp$YE&;>6s*gGGI)3FVS3w9LvkKd>9AS< zOE(+6L3GyYef8atXbxj*d)61IqI43oi3aYCPMx`i@6FUNi*6;j;z_Qxw#9}HG+VDQ zP3+qTCm)Y-)9$u$V-3U2dud&kS+w~+fA8PYAx$nfbz4mdS5SD=~u~lqS1)|C8zuq=c#fpQh<1J5l+M)H8p3;7wY*cI7@R+@+LjrZ>@Nj}UL~Gg z&O3>$ZZ^dy(n&|cmr0at`b!)SFquvMlv+d9-Z%`!*H$N9!5aMzSE}-|RK?|kM~F`y zQ_+Yp+eoD7P=I3(xt-O;B(+4Epr_GjDcn@$>QO~E?U%vjhlGBnU8l97by|QBt^xv-3ac5`kx$9$P5$NI6lp?_n4Ut>}VGbIa`6ou7B$p>*tD+gAtZ8|BeZaL-g zO~~>gmQ#mn2*#EoWRo-?TTCF7%a@*w?=VyChe3Qh4P{#lxeO&0@>sSoqBmHBc}%k` z{m64pLdPo^WcK>|xTp%z7xy1zMco2=O?anw%WcGBSzjC)9m{s?a*hl13eEBp_6O|3 z@$*jK;BTLvc3+pmAI09$s5Co_#W5>PAkXPHskN2pO57U=Au}g6Co~_9g&Z`D@UQg}o(?(13GAm{ zB=!(?piGW-@1EE$+TOyf3!DlJx|&7u=s3%|k8$;CZzOCp1eUQXBjDyr7l@RWmz1mU z)h*QVaq)9Eny&A6JMZGIe_pLGnC3*)D^(;}w{KuM3nm{17;#{`d_#vCv0fy-+txSb zU#`i^D-sq|aCmzvxkPhkAQj|{xKm2d+9cw%ba(1y;ol?|?|T7rjT z#$UDM7dHtQRe@EPJHVllzT^Xb`Iq&*IQ-EQ)5B*T!-7*Zne- z9wAF}zrhuFFUV^A{fcc!+1*8JQh&n$F)g2zakqcvD%!E^r0j|0-hM6+ID7&RVfzhbaI$bRoK zF0utm{|LieeT{*;*d3U%8SKZ2&*u!AjQmr?c)f+G`txp7gKoGK_+=hag^!QVb?pky z(Qri~Y)@GK5Xo&T`nl^QjQ98nw38DodvD0K#G@Wl;24DJ2NIJ+Wu`K<3M;#7bhjxR z@be2s67gRz2>f$UxRLa2>!1^z3jIW~G1ZtNUVMHnEG64%7 zMNXOz-$r2;ox@<2;T+;DW+iY;qce{5ASa_Di7(cBnWX-KKVuT$$l6u9|2Y&0Bpo?uI%v*}RsX;_-dlU%~TG^q_`m zDq;*XzZlkqah1~YZIl^l>Q$0>`~1bm&+E&jX0ubXTR7`>SMI8lcHK631)v%^H+j_k zdrLpgM47To_Z>@WV|&23cO0AFGR3#CN_l(LAxkPIjW8CRRV`DDpdqcNC#V4U;{NZKE7tV@_8L;)wf7` zjb%ZsRym#fjkpEzS`y?2$_L1n!g&*MT20|C{_4ezhxxplbtBDB`5#GHx^)FRne6Ii z0U>)#0U_*53;_auNhTI^=P^};kUfjK0dm5RD|C_Q{Jqp{gcRRSTUi2Z9gLOoetKx4 z?H2jNotgH9l)oWCc-a}WI`$eAXAQ`qEw+tcG zbX79GA8#;tFa#nT0UHVvv$56x({+U2bD(_>h&*g%9ayXgzgpwaIZOf zKaeG0I+-tr!A4e>ZL>6Pr&YXpKiF4Hzj&{g&7J0#uw#>;fZz2_A@(ey{fLIH{QcWl zG^C(RhxqCbbbxRkqUgE6#Sy<$8Mw`>rZ^xI)x?~BSD?o$%yUSB__)jnAUPK`0sUxj z9;h^b5SM0!W)eBJQq~L|iJ_grb@08mxzzHuHg#j;jj>?|o46%PpxouueuCu)+8 z;u~n@C2#Vt_xJnZ_L^k}HS6MmD>L3HCc0L;;yRHlEknv1%qL$824&@>G1chsvR`6_ zw8$2EijZt?4xvHMC!DF)P8m7Z4cHrT!XCg4h8xHJCCfcWTjL?G`@OH|?F9+%3_HK`)$uwpw-888WS?fd5!A9Vv>*Xh77cYGJrGG;lMbP$u|ik;eMwvUdp|$l zRW-DJJc=a89vV9>0_f;BnOa5KCMNChflxOHDbpsz2aJ;l7k008QB+}+v62s@n7}1t z#^Vox&Nbt?qs^!n%WGeHeVZNW(md6|KC|8XwGYXzGJ~6Rp$VItM!v-Ss z1vyGDUhCJUcMF}L7G3fEviB%}bn`f2?(X49PkpL6=Y!j?33d;>+YzHqv~EGN!lY%w z{z4*_LRj|BW(XcSeFy=uZ!4YcJ1-$(>vZXe~@A zmGqTXHPva-JvtO)mhNg9%%Hwt#f^79eP2Fjs)XR*1_`mXzox*vjU+S5lo>MNDr%~J76_-(&QNUmC(5b>42> zG>hpug_@Pct20NOnK%h_dLfbOu@K{mm)!8ahQJtffb#foklM%lIDYNOn6oLmp*Hes?G+$${6C{nhf7j%@qe^~Iy ze{|owUizff==f_vq5kW2U*uL^7zsHe7bnns;rrNe?XBkIZMw&U`y{W!-RcLH)Hjb4 zxt&SjYO_V}NudR zd{teJ{v#NBG@bXfu};VBl=79tCc-$mCnYUWo^Q7uq{c9{U$>aX&^GK}tf?4J2 zN;&KvOu6~*+;#=0T}5I)o(X){h@NCsDAT7Q99F^_4rIiVi=5O+DrZo{X&qCGz{0@N z^sV9@PKk7`Ry*3MKR45;L&LWSiw*^6iOpaUePr)ziaV56mWrWLn{l5?JV+-=AgR93 zxRAaS6x>r1ctP7)9G^_4=DIb-lQ3~= z)gHsHLi+)3>ss@kQQMEKE&2GXukX(*Uoo*76tPUqW-q zE|*p1t$r&L9*y^_Jy1O!SIy{7U(w(6MR;S8_ek%#>D{yFQ$H%u&yi_i8%2RIHRz!XJoJ*N4t;p<_V{M0+%v9Yp(fKo_YrEu-8;Cw zk~tTl9blF+dx~LygX(0f^UKh8z2m}%DQTkT$IF5{`g6T+e`1eI*eCsOo@+xu}0g67F5xZph#b!!binw6M2!kt8K*XPdC8#|>sYk6Z&mbd0237r^V z+qn3e z#3xa+uSz@c_V=Ot{?5S2n{n$4L11`+SSMlMj0wR9gO#_ZbJ516msmeoL;QDsDz-O& zvGzsl#{j0ddeU@oW7oK6jJl2O>4ub;3#cRgG+WmSm_ay1GY+VBU%i5<9v+){@=9sy z{?@MmY7YZX$Fu3plnA7U9MhaP$5V1w_WuFu^xO z#-q@vH(0A0CU8UV@v4!GTVWTxFAwl_4~a1#3C=0nLP#Ya?gkDhHh?6^sf)=Cu_VFNF zHLQXPlG@18wXP6nR9+JbTA7V2Ml@#uY3wsswt*REqoED^Qq-qfx&v^8nGO4VwOnp8 zR$uN6vBdzYeU{8j=h!}SzrNar9KKU_S}{lDUWbT!NmNY^mKL!DskVL&xL9HTz%53OC4OV^}K2AhEA2BB+rx zYW-hp1B02F$-jt4PGE5`#Y&oRYUt5Cm7s{?>ChtMr(_^2RqDD6X<2)}&IfOgzRx{9 z?RWJ{jh6}jk(0TfT^=!H7sboqBufon^>Pd*c8K&XsjO6s^ile|BMnh-fdQB_@oQ9i zk--fjq)a*&FUo~v!$^*OGmj#*(*I78Dq4b3zweWbfgX8cMBgOhxq)lUcQGpG_|%Z9 zNPa@EWkQLQo&^(aL_&V?O#MTHA!F=Ul#%1Xkzrhz-I6#Kq0(oGK0-1P*vCk^Ga2qE3+IN zoA9F}8-+!2lEIdAa4g}i)1h`MslGOFM_@!m- zyN{!mJCC(cj}X|hi*%eH4s!p_PCw^L|H@8tbMUhNoiOJ8cUG0{ITQTbgCNZ8{el=Od95zt_Pb4DpLkp?SyBy{yUeeL^g*3f7kL3W+ zOt%Zkl9$mEdb@;Bc>Ip4gDZ)*y3zu=WlPD)O6r0XF8#H&;Hc|DfV<*{tS{l!PHF2( zENORhzkDXt@vX*fNIK+LTI^reqs5!ttYQbK#hW%xr`;+Z<)msaV((MXvRn!sqMzC& zPCQTum0|BQl&M5)h>va;*pVz+8%)P+Ds#KFDv{i^3N8S0_Lm%pl|-@BENvYD3kw5ex<{z z+e)r?i*F44F4+ql%_%C%M07y(%M7$rtwuYdxAA(W&58=&&1_eLJU5C!gjuRoOegs# z+)c?2Evun8iWG`4Y|x7iHOUb@>uvF)l-GC41nvXw#(?p1|B;#ol^&6!K%Yp1jXd#dr5G3xR#4|91IrNB-q=4ql$WUC#EL{{K6Z{wII` zj4*g-Xrm7?v9`5>Jg4J-gAP7|ES?Yqic(S%A0!yx8JdEW?Gy|lCgA@Fh=VLl%$)#q zf64&A3+exqnuCmR-z)`NLknd9&!76w zgyi4E^AMj<0*35>XCT2pfd$V>ksuj(5F^`z^4&Bzk9fzCc*PRULenJE*C5NZ~5msKyDrYCmRO< z$jbJ+jFambhw|JO@XwrinwwTimL|`enE=m?{nM-ep5)I+7@OZ#d~ag#Y{O@;jQoF~ zE_Tnrn1A3+*#3#7k%d?p*y%rQprX?UnLw-nY|Lz2%&Y)9V_REm8$OmN(9RRr2m-P; zHqp0Xwt^VZ|5*m2Pur_(O{^^6Jp+E|-tn=svIAMUS=rctZ0sD|>h%BA_kV#xe`9w< zK~HPK%IJ5WfA{_`iDee&z`0kHgw z#|q?P|2L15?dc=&-*T+1Y)?M^ZypfH`#*VX9Bltfdpu98{J-ly4VvR=nE%z@6OZ%X zay%SQ=KmjZJe>d5%fcIY literal 0 HcmV?d00001 diff --git a/examples/dev/backward_trc_fusion b/examples/dev/backward_trc_fusion new file mode 100644 index 0000000000..7f9f9df5b0 --- /dev/null +++ b/examples/dev/backward_trc_fusion @@ -0,0 +1,63 @@ +digraph { + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "2([Symbol name=unpack_sequence])" + "0([Symbol name=unpack_trivial])" -> "2([Symbol name=unpack_sequence])" + "3([Symbol name=unpack_sequence])" + "1([Symbol name=unpack_trivial])" -> "3([Symbol name=unpack_sequence])" + "4([Symbol name=unpack_sequence])" + "2([Symbol name=unpack_sequence])" -> "4([Symbol name=unpack_sequence])" + "5([Symbol name=reshape])" + "3([Symbol name=unpack_sequence])" -> "5([Symbol name=reshape])" + "7([Symbol name=reshape])" + "4([Symbol name=unpack_sequence])" -> "7([Symbol name=reshape])" + "8([Symbol name=reshape])" + "4([Symbol name=unpack_sequence])" -> "8([Symbol name=reshape])" + "9([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "9([Symbol name=matmul])" + "5([Symbol name=reshape])" -> "9([Symbol name=matmul])" + "12([Symbol name=nvFusion0])" + "11([Symbol name=reshape])" -> "12([Symbol name=nvFusion0])" + "4([Symbol name=unpack_sequence])" -> "12([Symbol name=nvFusion0])" + "19([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "19([Symbol name=matmul])" + "15([Symbol name=reshape])" -> "19([Symbol name=matmul])" + "20([Symbol name=matmul])" + "4([Symbol name=unpack_sequence])" -> "20([Symbol name=matmul])" + "13([Symbol name=reshape])" -> "20([Symbol name=matmul])" + "6([Symbol name=transpose])" + "5([Symbol name=reshape])" -> "6([Symbol name=transpose])" + "10([Symbol name=matmul])" + "6([Symbol name=transpose])" -> "10([Symbol name=matmul])" + "7([Symbol name=reshape])" -> "10([Symbol name=matmul])" + "17([Symbol name=matmul])" + "16([Symbol name=transpose])" -> "17([Symbol name=matmul])" + "8([Symbol name=reshape])" -> "17([Symbol name=matmul])" + "18([Symbol name=matmul])" + "8([Symbol name=reshape])" -> "18([Symbol name=matmul])" + "14([Symbol name=transpose])" -> "18([Symbol name=matmul])" + "11([Symbol name=reshape])" + "9([Symbol name=matmul])" -> "11([Symbol name=reshape])" + "13([Symbol name=reshape])" + "12([Symbol name=nvFusion0])" -> "13([Symbol name=reshape])" + "15([Symbol name=reshape])" + "12([Symbol name=nvFusion0])" -> "15([Symbol name=reshape])" + "22([Symbol name=reshape])" + "19([Symbol name=matmul])" -> "22([Symbol name=reshape])" + "21([Symbol name=reshape])" + "20([Symbol name=matmul])" -> "21([Symbol name=reshape])" + "24([Symbol name=return])" + "17([Symbol name=matmul])" -> "24([Symbol name=return])" + "18([Symbol name=matmul])" -> "24([Symbol name=return])" + "10([Symbol name=matmul])" -> "24([Symbol name=return])" + "23([Symbol name=nvFusion1])" -> "24([Symbol name=return])" + "14([Symbol name=transpose])" + "13([Symbol name=reshape])" -> "14([Symbol name=transpose])" + "16([Symbol name=transpose])" + "15([Symbol name=reshape])" -> "16([Symbol name=transpose])" + "23([Symbol name=nvFusion1])" + "21([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" + "22([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" +} diff --git a/examples/dev/backward_trc_fusion.pdf b/examples/dev/backward_trc_fusion.pdf new file mode 100644 index 0000000000000000000000000000000000000000..31a387d1c8b03317011050e8fa7d397e8e4463f8 GIT binary patch literal 14244 zcma)j1y~$Q({9it1b1f%?z^}rxCeK4cPF?*aCZpq?(XguJh($}mkY@`=brEX{`)+4 zhn<Y070dYeRwUyuew%LGGGb5wc{!w3KZEOks^I5`1y;`$bbc18fE7nTeR z005v9G&Zx-w|TjmY1`@Z>+4$T>BDey!`Rx{=xdw9IHioJN`&LqA+=nWC2S09FFsI5 z1*oA5gJbp+R#?;`v7)d!R2H#@Z2_LSf1!O!s;X+ihDz>7B$lX9HiN(2pLxr6rFF71 zx}JWxW#z8+RDo%J-(`uA?xqXxal4~~@XWiob94V=Yp4vtfGOY;-7K!VWBF)O4JKQ+ ztJ4y(i0S()t&(^qVNPN;C6rD|?$h&IsZ7->3RIiHqWEgN98%lvb3|SZ32Hu?1anST zZc9rTgZbQv>TQuU1kLKQ>kUo#YfE$-1>KMj-)XuKzipgEnXv_`=b#2iTH?o#wYD%g zIz6t9)D6W8**1#xm4%T}_0dT(NIZ9)STa*F{?A8i(}9~XhsyYM#rUD>{N7TwyrG<=x1;tY^CfB_ zX6EKDTG1`LmsU3x&VrZj74vxm?M>5goIwQ8VIy;}OTBW9ev3xyenHp%CJ}u5+Y9o> zTW;efTSxZaxc6=MOs2#w{UbvsR2y|HHzJIYj=A(-lo%(ufHC$LUTJ;nt(K&1Zba)nHx*DNxthkP|I8;CGVZl;b%ucmptRhxA)I1tzBL&SF5P*1y)HZSF$_WS!yl_RHs4r!mS@oZats1+Q^=W z?x~(59~ho=^CvwLD%m*zOx^GWsFRenJcT}Wyo0^o%Rw?ADD-yt;6*thAB!i%x2%ng zyoMVT!aGMTcjaRy`7cJiOY6jF15-Lh^Ss@GJ4kR0{P+D6dH5I%O69$4Pj(MxeTa{PSBE2>kBF>A^CH|a-#aB@Zd5}uyxGDLDLP1 zZo7Vy&Msw@eChIvsfKPGTFnH+DEjGoXL)a6ue!gHX_k>XUE+a=$OF6xM@KJ~pyQ09 zIh-y;(x6uh9ZxgGpnC;4)+gWe5zuEz8QKLJ)h{KMC`mQ*=pkrw=gsJu1;^L$rv~=* zz87p)DDHq>A7Z!@#f6oQ=Y_v%RDf!Q{0Sbn@ZLBuRbIyn8C-)Td-qUWn6Q688wGg> zp>#RFcpx^cahgJZH_EEFUOH|cZ&B>lO1z+GRJIjQO}ogYxm$1xt^M+y6_1XRTy*-~ z#bXz3Ga`}qN6a^#M!IlP6(P#E{kmIR>`oZ%Z~)k`%dYd7dOZLf^)+nt98%WcI4TLxp3Yq+x0)n+go&EXA5QgJ;mv3s%Hxw4aXBoI zz3nMulr--bi_l7PS@ec*ErfFo+W2M?%%%9eRckSLS8}IN>2kj^NLBWF1V^uON3b+a zrIIVHL!%dI1>sS@33i4i2jJ`EQ1>~0qN;U2Z$+NCmMZUa^e6gN6)8F6)J_1z*C!fA zYufvWju0Z#FOY$WP)M31=ve?JNYQ(^~?Xhx{gai7R!le{ZjK-lom+DLI%*Mi} zLc}<8Yzfn}&7*hLcmO(Q3MgS`D9>sC#Q|b6AJznGcIKWbsc2$#D#I3bx>M^=>_bnV z3UQo!8}$leSSR#&nkA^E#MnDbsi^7!PxQvIN!xUY*3M}f04DP$(<*3^2Xs8hPY<0( z=F1bW&c5M0qxq4XPs_h-8)&f!-FfZY_c^$GuSJ%wFO-nG?D81H+PQNl0(-i?3ND>+ z$X7B-tS>LGD-~=38#~2+x}90h>pLi)x4K0_w+IxS%QSXw5Lj6E2o$XGH{*s9j7-^_ zEF=&5F6>Q**e$X|+4$7@*-o|wl|8Fev1@_IvOqT)DBY+L5Z@PC?5vo3lJa9Sn4ZjY9cULdATaJSl}#^yle^ojHz|a*R@9q zhJ)lC)HTUodx{|waAi9q=AS*er7Yha$A#me!r=2&pINFZYHEgSkz%(j+o}TPm9mZ9 z5+Q&4m}<(Uu-fCaDr;48Njfb9O~h<$pO-}GqN{Qga!d0r(VZP<+ii9Rzg=-qzwR?S zlI*VctPiaXJawGa3BE3w^rUtVcV*C%BK98VjGng_^RefRR1v)dY)da1oIRv&tb+fB z;U-4UL9XmzIh6g7dI{5R{*HuAER#PaG`^2YK?S*HU4Fep@1^$w+ z6!&`--lDs98O%fwhTm7MiS0~G>si5Plz|8xY;bH)~`2;lp-P{MJ`Ar zUa63RLy$M$;SKPv-hS(atwUcelkshfJ78K?303{rw3GE^t{cAa;LbNDk(Hr)8ZV4+^ECM^e| z(FtIl-bWS%V#{?-*<-^I(=^76D@sk34c*?Krc39K`SpzwmT&%SwSLQ zvkdgU5+UgQa`LzU<|VvrOXdmv!2M%FXa%+Cul4709l{W!3+vUv-=Ub|PPrxpx)I?( zib>aE%yz!RU;y~ojr%@}Z8v;Qy5$m`Q3O4y zjuVqU_Gtet*o{)cG--PTpzSa&rmp{d=A}dY^bCr!RgvxHy~#M;$(r33PNU9XfR(s8 z%U)}zMDu>~H5c4OTJ7OmYEo(c!lv88vnYJWcI4Ti!(wogN~jlxoOrE zZNe)|!sLg9_z47oz=LmiN>{rOw6pqZ7E|>IuFTK9Tv5ae9xCp;d&t zDp>fspPvb3+FP}>;$j!)EejGY-d)~_c!6PA2?xI11X&@zhdG-(ivBpauk;m)MzVZ= z05xsAkgVCYCN5`!rB~yvTVY=^88l&7=n0O3t@S!FDG76Zf95HQ6PZ9Vt@O+_jb#~2*f!&JU&XBsv?vHX z>#DCQJRe_U2#S^dz-SE7Q?R7FnBjeR>_np+&pT=vBXxO4V}jx6=^0 zgmy6;2nD=jlKwHU2nM!0LysXN7!Stt*5U&-`b~?9g#ig>YH;QLMn?y>ZRB4tCl2p8vf;)KHTrSTaRzM@9>f*HL4 z%62B`8BvR`%bDv5VzSAGE(}gN3CZC&n`q-4>OXKvC>8bbV|5vnEW@ z3ZV8UX0rkzleSB;18*qtWs38fn;{x-Mo18yd`%76J|$-7VM&|zQr}Ktua>C37xtQ@ z?=@|?7ieC*Dftjukh$4mNw^dA)0+aP&F{o_)r8Nmo|dT+yHKH)Cw6_|z{L2wWR9`K z)I*?m+tRRgI$#N5kivbPwf1Hgx=(uco;Y-a{6MC)-}<{!OWgOCkdp4|Guo40>wZ)f z0%Q5xJoa~INK+t>rFXQkgIt^S-KuN3X^N+$__}wvMYsvUhr_hnNga_TqO19!ODZN2 zpMbY`!hHPx7GKDmN_xq?lYF+At<|=@M(oq7YSBjSBPq}Pd+mn2m~6HxbJ_GratK3p zXG)7i^aG3 z#GF+Et2e9z8iD8Dj{J4mu_3z$GaO-D8OSbN_#%=m43K@tD)SpE(OXW34#Qujj`ms2 z*TZ2V`hd#@u;zYX?z48u$%WTu{R`rC=5t(Zi1?*BNowbEiDW8-W#3uNCuPCybSk>` zkf#?T6;&XeMO>=La_*9gXty2Ox^G=r6JQQ1WK4gcj@sKNEb?t030qMp=^Ht)?NNM= z&sImpYShWK?vabQjz3#VkWoy{dq4ujTTjt!J&KVY*L>F4yCa)%ekLloTRz5`vct|= zd9XGCcZtH6uJuRUznm}e5`yAWlWTQig*>(1cNQLo2DfA<6kDs{$)uC0P)KjFAM*MX zIdJdCdnG1K$8mDwbrUqx2h?fjMhckIm&h+qbdeOSaLNhFsASZ{HLa16NVcX(LYCmD z#;Jw$=ev;>^ArT7>Ws@awoqjj%UG#kF8;WSff_nnv{k81OH_)V*+)iJSt%$`{O)C>tfAY2L2oZL@kz0b9%yv1UKKPlFUm8e&?zw@<~3 zT zPbcmWMVw&o*4!uu#N)5QRaC`zJ3ln5pc`JM3!^+K&`b?_*SP2m4oA$8!kLs^9LH0H zV(~bdYPZ;pia4pO`W%z$K5D`=9xzEC$bUpMd{`n1EhnF#ozJ=aj`IeWv*G$c8dgk@|lAvK^WMX&~cqJMBN#y)a-3U0@ z3CY^Mq-_A4oUe?R51lsfg#gg;@c=*oI$iCT^Gh=4rQy$+PSDcA?qAXdG(Z|4fcfua z(2LN^-(Sk;|4{~@lXJGx2hd4r8|ni{0HEK>rL=AIU&aFc-Iq>6U(Z;Z*U|}~`oaPN z=z$<0fR%w6p#EE&??v&8ye;508TBXgBxGZ0Z}qRq{893o`?oc+FRE>{Eo`k`mFYVF z%@zUJ+1Ts*~u0@xC(uvQNlBxna1)$|Br+i>p#=_ z-x6Z4p8b4SxfKUArbA9cjez`=^3OT5jg4k zGi*Cyl2VKo?n92-8nm7-dmKNus`h+oxz;fK_4BSo$Op(P{COrYqkTyLfwlqFgq~>g z8=!kQsYypUgZciJ*7qS9v9oELCb4a|Xy3zBhW5(zqP9E7u;$d&_6^zpwkrHT)th8e@9YjwXCGVe9hDCi?h~! zrBvaoE?I>HLt%xdbS5uPEn?>Dtl; z_kLks!j0qB--(83-)?PY;HLCt9_J9Z-hZfB(Q3J>3#gad)!hgkn^$o`ON_vQuFU$# zps1&6M0Vxg=BTbDQx?`Gdm_gPCn>EujWLZe6W34=M%Sxf&!e6(Tfi|5m4PKFSyTpMN`X+8 z`qZOS4k%q3B;+9tk>B692=AO*{~uG5vVcZ2D0Air%=RI^LBLuxbG@y zo6*rGd$fP37=CxGFxRdJ&}h-Wb!%3&qyM-ro9a)jNr{}uSWwim(W##xS0G!v)rp?G zX;Gi$FnGLFEFx=&NqkUeweU3|(Me9Uk5v-{HCZ*8mrW;9_9gGuN#>jRLN@w+^*w0~ zYwj}3hJ1`aDk2M9&ax2>B?1<^tn#d;FjG#FEKR}F^k)+wlXX5w5z_$M0N-GsW}6iD z9N`1(hiy`LI0;131kxljrQUloNK!~L;!#wCB!eVF6$5AkXhTwT95Fu`*@A{{nRPn2 z7-%sR4F5u3O4-LZ#kXNHak+WGFNSI-Q03(Q+vFk`qiB8 zp5?A;ugeM8TV-Iv5sQbz;=UWU&0850RvbGlQdX3gVbkgFS4^hW2;M)i+h|N0D%J1G zCm}Fm_y8wMD(%OVWCDLGQVn{JqkEZnu9cVaKQ&z1HJ4V7g&Y!-WP{y)ocj;9K zcRvWjCVpB}`ljBa5c7<#Fr_u4U>m2J6oj!3GmOoj7b7a$z(_cPTI=6~Y7Ct=xFnW0 z{0VTM1*J4TTOq0!{GHHiee$-X33OI%36 zyfsN3G&5tdM z_GkU~4hijE`qegigh2DPr+N15ns8e_o^AnZ38xwKlJ=hniQ*TZnEUUf5A+!~B zJnIZ$5_82Jb9b7pR_s>XR;JLIwl}0RDu*ZxIeggW&oWTqe4dy6m?iIu%xp&KFg`Us zdk!*U#K@PJD%`FR*=9h|^{*(_-@Qq!aNWASb{R#-U>F!PV{@0g>{QM+);$anSG_Sg zjiOwgUu(xJY3GAS<#(hQ;GyJrFFM!r)j`>sghD8G7Ti6VX2o={f&nQHc6US+(k57t zPUt0ReOOaY`{$H{qd(WBuaV9ebX+qe~Klg9sY z=7E>}#!rPo(R4BQ7OIS4Y!_|H;ucDxv^XBzTHhLWRI-(=whQWs&@f_~AO*ggdamZ{ zb*Xw2j)QxNd(cLyT8H(%;znzO>+N9dP<&hBw<>uwp_!kBg)U!q7S;gWTtH31Y;xw( zC^|!F5+5RRGXPQ&$s4WA<@g%sAO}Z1yb`80RV_q2;Jfm^_6S#7`O_ypF2%6+6jTh; zyX*CYBS07tFln=hpuDQe_ z#`?x+kBgIYgHPu5p~gfN^6%hiYv%NX8BwLsQQuXlV~SXtDa?W_@H| zSLig@t(OM2HSv?nP2AQ}!ewPYwd7pejW2qwPw0_&t)<2W{^TdNEH)V;!EXB*Yok*VS|UHJjsvwz|TwdLFz4r1t20*^+Aku~(s_AL0>?yj3RMR^)Q z17DHGveKj6T!y!0i--9FzXonmT$T63P4owb?T(bgbt0O#2`3&D&i1!%n~q=`1bB|7 zgY$KN-xAGK-5y4NTeY%VAs3uKq12uknCveCiBKh&UCuW3LB({Mowth|mpv0Z+2m9Z z9Wa$mI91g5L1XhM6hlJ==O^;m5@f03+fv1^5*{Hwbu>kzJfP7?zTp6;UP61zt0@wp zG%ioWu~NwC%C+N)9kXA11j6vJOqoT|)BX@$n?_4+IKzCIC&R{x zioSo-Kz~zG3M~;Gv&cuw%@PSgB1=ag^A5QPa*wpa#Fov<5wdfqQV8O~HR? z1VufZI?d?k4=e3B0$JZ2o1Dsa?Q>2FbP7%L6AlLL!!h#CKVj^gpZDCBLLJ9GkSew~ zgo^Mmg5|u+E2_<2VmeuM&H)I$@{AdiKuj7VYlAH#pBZ zhg?k~Sv6gx+{c;wG&U1M7hWcP&=U!A zdB0nV+14!Wrm3lzeW#R73qFddDTo97mN4vSAJ^YGPNQdkVl^0}6b5b${eH^kWZ?s9 zy6WesoY{`E{3u;)JSco0;*CJ0IWT<}yUGUFHFd6|aij0*GK*W7R7#9h*Sp25JZY`e zog#2Cj@T)#!GwNeA*wivI6qLyHW&Fa%xdG92I@}rMMnwJ-0#ph#khb*OE;il1^2J& zQ@Wda@6+<}srCkb-b8;j{9&H&T>9f}G(vkA=k9xwNR|b{L3Ap8e)YvS2{u_GtBJ70 z0%oprxD^IHJrBoz@CEAUvZ>OShsr=#uH#V|e^KF1WFn-X*!;~HwJ6myJr}fWEkEhN zPD;kHf_l`tb?eUYkSm)vZ!u|~Hhpy1ohNe}5x$NZueh9a z42CqA72Z8clBJvNeRv)eEyn)g6^2;j0e>2UA|bWZZB%5(D7qSkv__4Hw%8MxvK8#d zgu(6tk^JEouhB*;O||u2RD)Kyh{4-DmsK+A}i4eVE1H-uYZ3vgHQ{dhc zXJF3G5cGW^w?fZ4G=bwb#H%ngLY0{cB#JmfO*L~!(@g#)7*oUO~@EurMzBxpoIMXmwF%Z)! zjitBhblTVsW`I0YBFk7%*u9ufI*t*4S-)ZE&X%!@oq(piR7-I@Rt>5Q=X(wTrt(RE zGhAzF=^YPg70mc-a+BSTG3#R9Qz*rhDVVzeu^4E^6bv@pj2dA!`z?0`{hK%+t>zBg zT=9`^=9*|*wR>F5u6j+o&7%H||PP_C2=f1==;zZZdEOj~0GRiQ;9M?zLa( z%E!j>N*M>GVGCj{mOfRe_lx4(f%Da@Jd@yzTsrryD99ws-e^;-+9}2<&MCSnvO`h4 zE>Z35cnJ(zaKsU$A)TPIzy=Z$3=Fjnh08kps>VnP)fFz}Hfhbr&A0`@THFr}L{B!` zvX{*`X*GrS7;9H|9%l1yR!yY8WLA^1wCW0WGwD^z0z&p_0z&AQDFZnDl8nvgE@P^& zZ1&CT1_`mg{3MS=;OHX(VSQ{oZ=(yab2O6A`{khqzgOfBd12BYQvM0|-P^99wQ*|g zIO90yIQ=*e1uX?Dh4qke=x6_S|5g8PxbH3=ZW-K4=}H9btDg{A5ji8A0GqOtKXM5i za2@a+@Ei!PaKku`SdTc5*pIl6*zgQzC^1`t&PHJT4XiA7!YMJ69%*O{OlynX^{&z` zlZ&`r>=#ZOdJc29h7?8@>UFR0%2pD8az88#qTMp|Nf0C;Ih(D7K}1%U?a(#tq*Z(p z8|p8nSbWsUW=Zo)*tJcN#pvdjjlGEII3^`87yBFw4-=H>7+)-{z4WA2v}psPX8+lKZ|Zr>%Jzy$WF!eG-ZlUCM7V zoF*HW#1FS6`Hi`h1z~vn$1EObIs@>p#=k6n43rkB?^v3LDg)qXWZq1oFitY;x%c~u zySevceg;c19!8yeXni-EjglquRnbF;K{F}3*qH>kv-l2-cG;Wo`{Uz5xPw~RVa?d4u$L6!HBWa>6X-=BAYBtExttV%^3>$qWL7 z!SS0FBQP#Z;w(~?rLTITc2-> z6M34aa+o#f(XV4zYK_*QS?kT$&dQ2#FC9WY=JT+@2wg6Q(kp7++Vmdo%k!ce_EQIs z0-GLIXQaJ-bkS+anoD-bgPLIXz=vG{l0>Uk7)vBF8kBFi0x4K!{IBvrSoP27jL%)`+gBS{x~; zKUi*)->3hZb*5qn#(fYssDt_=(tRX>VW#-7F*}E@Ja!5E86R*LGeWu4a3P~)>e-)b zAA@E5L2_^?bnI}+@uVSFIoURM3gNnFDk^Haw5(13nNRic=(6qWSpLnEg(VF{L+rz3I`{^Kr;bIElDdsMd&BUNk9X9 z!706&kQLI(cN{NR1Q;BrYC3&!qFvwSV^3Zq}Y-YWgBa zw~xT?*)q2iZ8}YDgtHja*bnU&mZ%mfSW<8~i9S9pc;r93@82#^(&meBiYXor*+R9IwwL_7lGz3%LSsT>_SVLD@WsQ@B4k6 z627wCc8rV3{qgc*-7gERv=m09}_iFQwOw)N%if*@ZuBk1*4R!*&#KmZKy>n-_xVm8Ta@OO?)_t zkoFu3AwfeFuMvwUl>g=9ER66@cV~X|Og3+>N}emLTtz;I-h(DLpWkhdYsQr~_R9t5 zl4|r6P_|5$6l+8tc_feuSvqn`GpU?X4z+DuIRY6GStiwD2S2|MkX`y7mmA$mRFua3U zER5RpGYwo3XgzmKUD6lVoR3F#o?WMIc92)T;hb{MC4arHDsS`qIqA{#xYqlo*W;!I z!TCGFr~U|UWWrw2eK(y)I$e@yS&BIVbrd6fBL_EH*i}EIP3Q%5UsiZ8UB-{_Qte$2 z7Z7^AM1eN=bMFe#M# zdJQohEW5@fxY%)RQse=1H{frm*?xb%+K4qWk*-O$bKyqE)>jNE$pAH)yQ|5Vr9hve zKhU5w)#maxl)rC+=O`tO@A71YYgc!!@AEH|2@yM_Pr+9k2ouR4kVg=pica6KSe7KZtCr&TWpTKhJl4GUO+fNJKJnz8`H~ z(fCE(@AFmGu)V+ms&4Hm)fvl00v3c&O`0AfYw(Te3Gdb!k0st9cjTDk=`lJ@dy?$I z(5gTivd_~mLOXJwPx$KbG5VNflC0(0(p($eAf$0Za9v1}&-ACO3XM_V1Aiys0YlxJAyX~sCHkw0k0kzpM8l+Y3lW0!dtjGcnHTe5+Jx0LKAkj z7iO@JkzW&}%61Pnf~h$qD1?ru2^w)v7e2A`fwQ+%CGyX^> zzIj6h@4_jnVI46`eyd5zmKzl>e`XqsD?bn!*R&gXkiA7~qeK!g8-XEz<{iuFF)p6& z7~x}ZrWW&17{QtR`ATSjK@wxqQ0~l*(7A9$G+mT*1P$E$3#mkOKX(u(nW0CsGg)x?6C+b+=+%yEF<}$#f?>+&ta*-K}oiTcDgWW6aG+UlU${3K76C^8hPj#fJhsYLZ!=B9%tif5b#2qtVgtD*-G{@hs{^FcU=Eb-Z{@q-To zy2_i<7|z*g#SD@AoxCa~Q8hWp>NrC98oD_MLIc=T%IYMOGY&E&QG+g(qNHfx$_*bs zD{7EY#a(Anf$KScs)XfMiTfN^xJeirLsv0w6H8qYK?0*%>rbu07)(n`_)Rc!5}AP} zR@j(HRfqJY1(pw8lk5XVO2!BIN-cM8bt`Y`e1n}avE1|X0aw4&c=6!XoXmsl@`z#k zC^kxGDH5ovx8vZk!}y=^r9@gqkJI^&Rrw4G^bGJ5zelAP>D|4zDHF{_kNS$cX(&y< zmG?2V(*Hq@I9iBGw_j3RPlqruqJQfBrJifdk^r$wd}>HlBnOt)3YO4W?}D+$dn^vY zOx+{BVIvegqR5Hh$S`K49${2-{JcuX>1b&)XgqrsEVfP_p)Vha{6E!2%1i>>-=6}wtxWqCdR!+oE ziWQ=N%W9k(pjTS~fegvsE&?rAh1JWOBrW-I4<0*_f08H2?yZR893vuENksP5SY|<% zhzDY^2m1Ri!s4hHZQ}X$?*q5Q6na@AvLf%eKg`PShSX${Y+gDL8G{};u>>Z2ymB^$odvu?VPqDPnY&Ecj_~# zf?+vfi`ywp*Xr=L9zNdqZVe?sIo_mcChcDCI44zO5#`_`8QnGa5yH8B;^fmu?lP1E z$}+`>O~J990(;y=E4`VREd>_0HhJ8KHm(Ig&cU)HjyxZ-O58wx?tfnw#078 z-0mxoEqx(~+m=3v%8F#k=d1Qx2N{AE$~}EdcXSOe#ZIRBFXH);X1qM7*kXs{2lPffLVtCh({bvCJI+bxMP+)VdW33H<`d1*yT1vHa?LEe?@k^%Jv zVMSnhAcCMfm4!!jtab#CQ>Y&lFx>~;jQ|tn{-ZU!#8&iY<8>IZWbyOOV^K}r3BBY9 zR%rx@DTF6(y?NGA#Sf3*3xWOk{~hw*f&44x3~a1_haB`;{{Op@{-=KbiZI~Ux7D>V zwz9Ldc`e8P1|59SHh)17$ccyuNeEH#>zimR*vsnM7#sW}AgFC_Y~~Cg|I-He-AM7L z)lA#)rG`)ck0|dSlmQJ3GaG=0fsq+N4+1g)Kp-I7Utj|fJ8d&#T^eRBl>>!1Fwgz(?QtBEhm zS)U&83MBY9u;3L*@!#(FGn~DS-ES0y+zWK!)m6W_yxO+{AK?CJaPgSFS~h+Eerv4ipCZ^7PiKJGk&X9G}f~-vVB2oFo0go z`1irU1_Zo*|9vpOKso+@zJL0#GW^dmUtlI|uQ-hV{a|K&g}nUx_@n>NnF;vX!o>D6 z1ICvt^B?Nf(3epe8DA`TefKcGEP~fNHb&OpTxKBsZ~50gj4Z4GCJ+OF5eWL-#>D)J zLwOyG@z0uhS(}#f7RImJp8&6e{X44vy~tmYFt)$FC}ynp>cdyCjLbh!7yDOW%)js^ zpnqd&q--qp>~&wZT#@T)8{1d{K(rucS|EVj$j;8nmYwbewDZC>w9&RQGS;=FwX`v$ z`12U>y==&`Gq$wge+B%I^Rv?f=^246KoE!#M9;vYLh)~X{}(9qH+F|l`{kLiH2gj1 z-?RSU*h`u0a*OS12QtR{5Ow@ z`Q=CAKjnZx&`XT}lgG%!_&<4|m*x21_+ gFn9^;? "13([Symbol name=return])" + "1([Symbol name=unpack_trivial])" -> "13([Symbol name=return])" + "2([Symbol name=unpack_trivial])" -> "13([Symbol name=return])" + "3([Symbol name=unpack_trivial])" -> "13([Symbol name=return])" + "4([Symbol name=linear])" -> "13([Symbol name=return])" + "5([Symbol name=linear])" -> "13([Symbol name=return])" + "7([Symbol name=exp])" -> "13([Symbol name=return])" + "9([Symbol name=reciprocal])" -> "13([Symbol name=return])" + "10([Symbol name=mul])" -> "13([Symbol name=return])" + "11([Symbol name=mul])" -> "13([Symbol name=return])" + "12([Symbol name=linear])" -> "13([Symbol name=return])" + "4([Symbol name=linear])" + "0([Symbol name=unpack_trivial])" -> "4([Symbol name=linear])" + "1([Symbol name=unpack_trivial])" -> "4([Symbol name=linear])" + "5([Symbol name=linear])" + "0([Symbol name=unpack_trivial])" -> "5([Symbol name=linear])" + "2([Symbol name=unpack_trivial])" -> "5([Symbol name=linear])" + "12([Symbol name=linear])" + "3([Symbol name=unpack_trivial])" -> "12([Symbol name=linear])" + "11([Symbol name=mul])" -> "12([Symbol name=linear])" + "10([Symbol name=mul])" + "9([Symbol name=reciprocal])" -> "10([Symbol name=mul])" + "4([Symbol name=linear])" -> "10([Symbol name=mul])" + "6([Symbol name=neg])" + "4([Symbol name=linear])" -> "6([Symbol name=neg])" + "11([Symbol name=mul])" + "10([Symbol name=mul])" -> "11([Symbol name=mul])" + "5([Symbol name=linear])" -> "11([Symbol name=mul])" + "7([Symbol name=exp])" + "6([Symbol name=neg])" -> "7([Symbol name=exp])" + "8([Symbol name=add])" + "7([Symbol name=exp])" -> "8([Symbol name=add])" + "9([Symbol name=reciprocal])" + "8([Symbol name=add])" -> "9([Symbol name=reciprocal])" +} diff --git a/examples/dev/forward_trc.dot b/examples/dev/forward_trc.dot new file mode 100644 index 0000000000..866907bfe1 --- /dev/null +++ b/examples/dev/forward_trc.dot @@ -0,0 +1,4518 @@ +digraph { + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "2([Symbol name=unpack_trivial])" + "3([Symbol name=unpack_trivial])" + "4([Symbol name=unpack_trivial])" + "5([Symbol name=unpack_trivial])" + "6([Symbol name=unpack_trivial])" + "7([Symbol name=unpack_trivial])" + "8([Symbol name=unpack_trivial])" + "9([Symbol name=unpack_trivial])" + "10([Symbol name=unpack_trivial])" + "11([Symbol name=unpack_trivial])" + "12([Symbol name=unpack_trivial])" + "13([Symbol name=unpack_trivial])" + "14([Symbol name=unpack_trivial])" + "15([Symbol name=unpack_trivial])" + "16([Symbol name=unpack_trivial])" + "17([Symbol name=unpack_trivial])" + "18([Symbol name=unpack_trivial])" + "19([Symbol name=unpack_trivial])" + "20([Symbol name=unpack_trivial])" + "21([Symbol name=unpack_trivial])" + "22([Symbol name=unpack_trivial])" + "23([Symbol name=unpack_trivial])" + "24([Symbol name=unpack_trivial])" + "25([Symbol name=unpack_trivial])" + "26([Symbol name=unpack_trivial])" + "27([Symbol name=unpack_trivial])" + "28([Symbol name=unpack_trivial])" + "29([Symbol name=unpack_trivial])" + "30([Symbol name=unpack_trivial])" + "31([Symbol name=unpack_trivial])" + "32([Symbol name=unpack_trivial])" + "33([Symbol name=unpack_trivial])" + "34([Symbol name=unpack_trivial])" + "35([Symbol name=unpack_trivial])" + "36([Symbol name=unpack_trivial])" + "37([Symbol name=unpack_trivial])" + "38([Symbol name=unpack_trivial])" + "39([Symbol name=unpack_trivial])" + "40([Symbol name=unpack_trivial])" + "41([Symbol name=unpack_trivial])" + "42([Symbol name=unpack_trivial])" + "43([Symbol name=unpack_trivial])" + "44([Symbol name=unpack_trivial])" + "45([Symbol name=unpack_trivial])" + "46([Symbol name=unpack_trivial])" + "47([Symbol name=unpack_trivial])" + "48([Symbol name=unpack_trivial])" + "49([Symbol name=unpack_trivial])" + "50([Symbol name=unpack_trivial])" + "51([Symbol name=unpack_trivial])" + "52([Symbol name=unpack_trivial])" + "53([Symbol name=unpack_trivial])" + "54([Symbol name=unpack_trivial])" + "55([Symbol name=unpack_trivial])" + "56([Symbol name=unpack_trivial])" + "57([Symbol name=unpack_trivial])" + "58([Symbol name=unpack_trivial])" + "59([Symbol name=unpack_trivial])" + "60([Symbol name=unpack_trivial])" + "61([Symbol name=unpack_trivial])" + "62([Symbol name=unpack_trivial])" + "63([Symbol name=unpack_trivial])" + "64([Symbol name=unpack_trivial])" + "65([Symbol name=unpack_trivial])" + "66([Symbol name=unpack_trivial])" + "67([Symbol name=unpack_trivial])" + "68([Symbol name=unpack_trivial])" + "69([Symbol name=unpack_trivial])" + "70([Symbol name=unpack_trivial])" + "71([Symbol name=unpack_trivial])" + "72([Symbol name=unpack_trivial])" + "73([Symbol name=unpack_trivial])" + "74([Symbol name=unpack_trivial])" + "75([Symbol name=unpack_trivial])" + "76([Symbol name=unpack_trivial])" + "77([Symbol name=unpack_trivial])" + "78([Symbol name=unpack_trivial])" + "79([Symbol name=unpack_trivial])" + "80([Symbol name=unpack_trivial])" + "81([Symbol name=unpack_trivial])" + "82([Symbol name=unpack_trivial])" + "83([Symbol name=unpack_trivial])" + "84([Symbol name=unpack_trivial])" + "85([Symbol name=unpack_trivial])" + "86([Symbol name=unpack_trivial])" + "87([Symbol name=unpack_trivial])" + "88([Symbol name=unpack_trivial])" + "89([Symbol name=unpack_trivial])" + "90([Symbol name=unpack_trivial])" + "91([Symbol name=unpack_trivial])" + "92([Symbol name=unpack_trivial])" + "93([Symbol name=unpack_trivial])" + "94([Symbol name=unpack_trivial])" + "95([Symbol name=unpack_trivial])" + "96([Symbol name=unpack_trivial])" + "97([Symbol name=unpack_trivial])" + "98([Symbol name=unpack_trivial])" + "99([Symbol name=unpack_trivial])" + "100([Symbol name=unpack_trivial])" + "101([Symbol name=unpack_trivial])" + "102([Symbol name=unpack_trivial])" + "103([Symbol name=unpack_trivial])" + "104([Symbol name=unpack_trivial])" + "105([Symbol name=unpack_trivial])" + "106([Symbol name=unpack_trivial])" + "107([Symbol name=unpack_trivial])" + "108([Symbol name=unpack_trivial])" + "109([Symbol name=unpack_trivial])" + "110([Symbol name=unpack_trivial])" + "111([Symbol name=unpack_trivial])" + "112([Symbol name=unpack_trivial])" + "113([Symbol name=unpack_trivial])" + "114([Symbol name=unpack_trivial])" + "115([Symbol name=unpack_trivial])" + "116([Symbol name=unpack_trivial])" + "117([Symbol name=unpack_trivial])" + "0([Symbol name=unpack_trivial])" + "1([Symbol name=unpack_trivial])" + "2([Symbol name=unpack_trivial])" + "3([Symbol name=unpack_trivial])" + "4([Symbol name=unpack_trivial])" + "5([Symbol name=unpack_trivial])" + "6([Symbol name=unpack_trivial])" + "7([Symbol name=unpack_trivial])" + "8([Symbol name=unpack_trivial])" + "9([Symbol name=unpack_trivial])" + "10([Symbol name=unpack_trivial])" + "11([Symbol name=unpack_trivial])" + "12([Symbol name=unpack_trivial])" + "13([Symbol name=unpack_trivial])" + "14([Symbol name=unpack_trivial])" + "15([Symbol name=unpack_trivial])" + "16([Symbol name=unpack_trivial])" + "17([Symbol name=unpack_trivial])" + "18([Symbol name=unpack_trivial])" + "19([Symbol name=unpack_trivial])" + "20([Symbol name=unpack_trivial])" + "21([Symbol name=unpack_trivial])" + "22([Symbol name=unpack_trivial])" + "23([Symbol name=unpack_trivial])" + "24([Symbol name=unpack_trivial])" + "25([Symbol name=unpack_trivial])" + "26([Symbol name=unpack_trivial])" + "27([Symbol name=unpack_trivial])" + "28([Symbol name=unpack_trivial])" + "29([Symbol name=unpack_trivial])" + "30([Symbol name=unpack_trivial])" + "31([Symbol name=unpack_trivial])" + "32([Symbol name=unpack_trivial])" + "33([Symbol name=unpack_trivial])" + "34([Symbol name=unpack_trivial])" + "35([Symbol name=unpack_trivial])" + "36([Symbol name=unpack_trivial])" + "37([Symbol name=unpack_trivial])" + "38([Symbol name=unpack_trivial])" + "39([Symbol name=unpack_trivial])" + "40([Symbol name=unpack_trivial])" + "41([Symbol name=unpack_trivial])" + "42([Symbol name=unpack_trivial])" + "43([Symbol name=unpack_trivial])" + "44([Symbol name=unpack_trivial])" + "45([Symbol name=unpack_trivial])" + "46([Symbol name=unpack_trivial])" + "47([Symbol name=unpack_trivial])" + "48([Symbol name=unpack_trivial])" + "49([Symbol name=unpack_trivial])" + "50([Symbol name=unpack_trivial])" + "51([Symbol name=unpack_trivial])" + "52([Symbol name=unpack_trivial])" + "53([Symbol name=unpack_trivial])" + "54([Symbol name=unpack_trivial])" + "55([Symbol name=unpack_trivial])" + "56([Symbol name=unpack_trivial])" + "57([Symbol name=unpack_trivial])" + "58([Symbol name=unpack_trivial])" + "59([Symbol name=unpack_trivial])" + "60([Symbol name=unpack_trivial])" + "61([Symbol name=unpack_trivial])" + "62([Symbol name=unpack_trivial])" + "63([Symbol name=unpack_trivial])" + "64([Symbol name=unpack_trivial])" + "65([Symbol name=unpack_trivial])" + "66([Symbol name=unpack_trivial])" + "67([Symbol name=unpack_trivial])" + "68([Symbol name=unpack_trivial])" + "69([Symbol name=unpack_trivial])" + "70([Symbol name=unpack_trivial])" + "71([Symbol name=unpack_trivial])" + "72([Symbol name=unpack_trivial])" + "73([Symbol name=unpack_trivial])" + "74([Symbol name=unpack_trivial])" + "75([Symbol name=unpack_trivial])" + "76([Symbol name=unpack_trivial])" + "77([Symbol name=unpack_trivial])" + "78([Symbol name=unpack_trivial])" + "79([Symbol name=unpack_trivial])" + "80([Symbol name=unpack_trivial])" + "81([Symbol name=unpack_trivial])" + "82([Symbol name=unpack_trivial])" + "83([Symbol name=unpack_trivial])" + "84([Symbol name=unpack_trivial])" + "85([Symbol name=unpack_trivial])" + "86([Symbol name=unpack_trivial])" + "87([Symbol name=unpack_trivial])" + "88([Symbol name=unpack_trivial])" + "89([Symbol name=unpack_trivial])" + "90([Symbol name=unpack_trivial])" + "91([Symbol name=unpack_trivial])" + "92([Symbol name=unpack_trivial])" + "93([Symbol name=unpack_trivial])" + "94([Symbol name=unpack_trivial])" + "95([Symbol name=unpack_trivial])" + "96([Symbol name=unpack_trivial])" + "97([Symbol name=unpack_trivial])" + "98([Symbol name=unpack_trivial])" + "99([Symbol name=unpack_trivial])" + "100([Symbol name=unpack_trivial])" + "101([Symbol name=unpack_trivial])" + "102([Symbol name=unpack_trivial])" + "103([Symbol name=unpack_trivial])" + "104([Symbol name=unpack_trivial])" + "105([Symbol name=unpack_trivial])" + "106([Symbol name=unpack_trivial])" + "107([Symbol name=unpack_trivial])" + "108([Symbol name=unpack_trivial])" + "109([Symbol name=unpack_trivial])" + "110([Symbol name=unpack_trivial])" + "111([Symbol name=unpack_trivial])" + "112([Symbol name=unpack_trivial])" + "113([Symbol name=unpack_trivial])" + "114([Symbol name=unpack_trivial])" + "115([Symbol name=unpack_trivial])" + "116([Symbol name=unpack_trivial])" + "117([Symbol name=unpack_trivial])" + "120([Symbol name=embedding])" + "0([Symbol name=unpack_trivial])" -> "120([Symbol name=embedding])" + "117([Symbol name=unpack_trivial])" -> "120([Symbol name=embedding])" + "1737([Symbol name=return])" + "0([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "1([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "2([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "3([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "4([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "5([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "6([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "7([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "8([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "9([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "10([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "11([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "12([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "13([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "14([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "15([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "16([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "17([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "18([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "19([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "20([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "21([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "22([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "23([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "24([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "25([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "26([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "27([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "28([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "29([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "30([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "31([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "32([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "33([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "34([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "35([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "36([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "37([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "38([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "39([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "40([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "41([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "42([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "43([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "44([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "45([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "46([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "47([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "48([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "49([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "50([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "51([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "52([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "53([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "54([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "55([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "56([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "57([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "58([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "59([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "60([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "61([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "62([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "63([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "64([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "65([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "66([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "67([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "68([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "69([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "70([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "71([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "72([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "73([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "74([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "75([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "76([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "77([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "78([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "79([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "80([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "81([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "82([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "83([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "84([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "85([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "86([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "87([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "88([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "89([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "90([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "91([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "92([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "93([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "94([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "95([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "96([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "97([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "98([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "99([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "100([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "101([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "102([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "103([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "104([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "105([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "106([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "107([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "108([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "109([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "110([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "111([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "112([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "113([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "114([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "115([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "116([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "117([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" + "121([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "127([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "128([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "132([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "133([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "135([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "142([Symbol name=reshape])" -> "1737([Symbol name=return])" + "150([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "151([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "153([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "154([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "165([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "166([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "168([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "169([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "174([Symbol name=cat])" -> "1737([Symbol name=return])" + "176([Symbol name=cat])" -> "1737([Symbol name=return])" + "177([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "179([Symbol name=reshape])" -> "1737([Symbol name=return])" + "185([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "191([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "192([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "196([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "197([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "199([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "204([Symbol name=exp])" -> "1737([Symbol name=return])" + "206([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "208([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "209([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "212([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "213([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "215([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "221([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "227([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "228([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "232([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "233([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "235([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "242([Symbol name=reshape])" -> "1737([Symbol name=return])" + "250([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "251([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "253([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "254([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "265([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "266([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "268([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "269([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "274([Symbol name=cat])" -> "1737([Symbol name=return])" + "276([Symbol name=cat])" -> "1737([Symbol name=return])" + "277([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "279([Symbol name=reshape])" -> "1737([Symbol name=return])" + "285([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "291([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "292([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "296([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "297([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "299([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "304([Symbol name=exp])" -> "1737([Symbol name=return])" + "306([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "308([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "309([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "312([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "313([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "315([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "321([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "327([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "328([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "332([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "333([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "335([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "342([Symbol name=reshape])" -> "1737([Symbol name=return])" + "350([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "351([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "353([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "354([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "365([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "366([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "368([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "369([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "374([Symbol name=cat])" -> "1737([Symbol name=return])" + "376([Symbol name=cat])" -> "1737([Symbol name=return])" + "377([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "379([Symbol name=reshape])" -> "1737([Symbol name=return])" + "385([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "391([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "392([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "396([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "397([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "399([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "404([Symbol name=exp])" -> "1737([Symbol name=return])" + "406([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "408([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "409([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "412([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "413([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "415([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "421([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "427([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "428([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "432([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "433([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "435([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "442([Symbol name=reshape])" -> "1737([Symbol name=return])" + "450([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "451([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "453([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "454([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "465([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "466([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "468([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "469([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "474([Symbol name=cat])" -> "1737([Symbol name=return])" + "476([Symbol name=cat])" -> "1737([Symbol name=return])" + "477([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "479([Symbol name=reshape])" -> "1737([Symbol name=return])" + "485([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "491([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "492([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "496([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "497([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "499([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "504([Symbol name=exp])" -> "1737([Symbol name=return])" + "506([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "508([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "509([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "512([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "513([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "515([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "521([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "527([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "528([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "532([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "533([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "535([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "542([Symbol name=reshape])" -> "1737([Symbol name=return])" + "550([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "551([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "553([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "554([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "565([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "566([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "568([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "569([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "574([Symbol name=cat])" -> "1737([Symbol name=return])" + "576([Symbol name=cat])" -> "1737([Symbol name=return])" + "577([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "579([Symbol name=reshape])" -> "1737([Symbol name=return])" + "585([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "591([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "592([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "596([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "597([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "599([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "604([Symbol name=exp])" -> "1737([Symbol name=return])" + "606([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "608([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "609([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "612([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "613([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "615([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "621([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "627([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "628([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "632([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "633([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "635([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "642([Symbol name=reshape])" -> "1737([Symbol name=return])" + "650([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "651([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "653([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "654([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "665([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "666([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "668([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "669([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "674([Symbol name=cat])" -> "1737([Symbol name=return])" + "676([Symbol name=cat])" -> "1737([Symbol name=return])" + "677([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "679([Symbol name=reshape])" -> "1737([Symbol name=return])" + "685([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "691([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "692([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "696([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "697([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "699([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "704([Symbol name=exp])" -> "1737([Symbol name=return])" + "706([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "708([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "709([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "712([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "713([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "715([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "721([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "727([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "728([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "732([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "733([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "735([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "742([Symbol name=reshape])" -> "1737([Symbol name=return])" + "750([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "751([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "753([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "754([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "765([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "766([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "768([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "769([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "774([Symbol name=cat])" -> "1737([Symbol name=return])" + "776([Symbol name=cat])" -> "1737([Symbol name=return])" + "777([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "779([Symbol name=reshape])" -> "1737([Symbol name=return])" + "785([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "791([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "792([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "796([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "797([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "799([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "804([Symbol name=exp])" -> "1737([Symbol name=return])" + "806([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "808([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "809([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "812([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "813([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "815([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "821([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "827([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "828([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "832([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "833([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "835([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "842([Symbol name=reshape])" -> "1737([Symbol name=return])" + "850([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "851([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "853([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "854([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "865([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "866([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "868([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "869([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "874([Symbol name=cat])" -> "1737([Symbol name=return])" + "876([Symbol name=cat])" -> "1737([Symbol name=return])" + "877([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "879([Symbol name=reshape])" -> "1737([Symbol name=return])" + "885([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "891([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "892([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "896([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "897([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "899([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "904([Symbol name=exp])" -> "1737([Symbol name=return])" + "906([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "908([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "909([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "912([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "913([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "915([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "921([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "927([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "928([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "932([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "933([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "935([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "942([Symbol name=reshape])" -> "1737([Symbol name=return])" + "950([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "951([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "953([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "954([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "965([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "966([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "968([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "969([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "974([Symbol name=cat])" -> "1737([Symbol name=return])" + "976([Symbol name=cat])" -> "1737([Symbol name=return])" + "977([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "979([Symbol name=reshape])" -> "1737([Symbol name=return])" + "985([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "991([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "992([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "996([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "997([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "999([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1004([Symbol name=exp])" -> "1737([Symbol name=return])" + "1006([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1008([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1009([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1012([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1013([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1015([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1021([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1027([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1028([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1032([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1033([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1035([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1042([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1050([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1051([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1053([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1054([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1065([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1066([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1068([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1069([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1074([Symbol name=cat])" -> "1737([Symbol name=return])" + "1076([Symbol name=cat])" -> "1737([Symbol name=return])" + "1077([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "1079([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1085([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1091([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1092([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1096([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1097([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1099([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1104([Symbol name=exp])" -> "1737([Symbol name=return])" + "1106([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1108([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1109([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1112([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1113([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1115([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1121([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1127([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1128([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1132([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1133([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1135([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1142([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1150([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1151([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1153([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1154([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1165([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1166([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1168([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1169([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1174([Symbol name=cat])" -> "1737([Symbol name=return])" + "1176([Symbol name=cat])" -> "1737([Symbol name=return])" + "1177([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "1179([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1185([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1191([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1192([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1196([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1197([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1199([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1204([Symbol name=exp])" -> "1737([Symbol name=return])" + "1206([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1208([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1209([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1212([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1213([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1215([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1221([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1227([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1228([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1232([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1233([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1235([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1242([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1250([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1251([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1253([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1254([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1265([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1266([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1268([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1269([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1274([Symbol name=cat])" -> "1737([Symbol name=return])" + "1276([Symbol name=cat])" -> "1737([Symbol name=return])" + "1277([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "1279([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1285([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1291([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1292([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1296([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1297([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1299([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1304([Symbol name=exp])" -> "1737([Symbol name=return])" + "1306([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1308([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1309([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1312([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1313([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1315([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1321([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1327([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1328([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1332([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1333([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1335([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1342([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1350([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1351([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1353([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1354([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1365([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1366([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1368([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1369([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1374([Symbol name=cat])" -> "1737([Symbol name=return])" + "1376([Symbol name=cat])" -> "1737([Symbol name=return])" + "1377([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "1379([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1385([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1391([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1392([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1396([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1397([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1399([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1404([Symbol name=exp])" -> "1737([Symbol name=return])" + "1406([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1408([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1409([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1412([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1413([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1415([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1421([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1427([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1428([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1432([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1433([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1435([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1442([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1450([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1451([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1453([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1454([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1465([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1466([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1468([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1469([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1474([Symbol name=cat])" -> "1737([Symbol name=return])" + "1476([Symbol name=cat])" -> "1737([Symbol name=return])" + "1477([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "1479([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1485([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1491([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1492([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1496([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1497([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1499([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1504([Symbol name=exp])" -> "1737([Symbol name=return])" + "1506([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1508([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1509([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1512([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1513([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1515([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1521([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1527([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1528([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1532([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1533([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1535([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1542([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1550([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1551([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1553([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1554([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1565([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1566([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1568([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1569([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1574([Symbol name=cat])" -> "1737([Symbol name=return])" + "1576([Symbol name=cat])" -> "1737([Symbol name=return])" + "1577([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "1579([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1585([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1591([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1592([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1596([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1597([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1599([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1604([Symbol name=exp])" -> "1737([Symbol name=return])" + "1606([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1608([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1609([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1612([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1613([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1615([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1621([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1627([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1628([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1632([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1633([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1635([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1642([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1650([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1651([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1653([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1654([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1665([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1666([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1668([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1669([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1674([Symbol name=cat])" -> "1737([Symbol name=return])" + "1676([Symbol name=cat])" -> "1737([Symbol name=return])" + "1677([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" + "1679([Symbol name=reshape])" -> "1737([Symbol name=return])" + "1685([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1691([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1692([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1696([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1697([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1699([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1704([Symbol name=exp])" -> "1737([Symbol name=return])" + "1706([Symbol name=reciprocal])" -> "1737([Symbol name=return])" + "1708([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1709([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1712([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1713([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1715([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1721([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1727([Symbol name=rsqrt])" -> "1737([Symbol name=return])" + "1728([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" + "1732([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1733([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1735([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" + "1736([Symbol name=linear])" -> "1737([Symbol name=return])" + "118([Symbol name=slice_prim])" + "1([Symbol name=unpack_trivial])" -> "118([Symbol name=slice_prim])" + "1736([Symbol name=linear])" + "2([Symbol name=unpack_trivial])" -> "1736([Symbol name=linear])" + "1735([Symbol name=convert_element_type])" -> "1736([Symbol name=linear])" + "119([Symbol name=slice_prim])" + "3([Symbol name=unpack_trivial])" -> "119([Symbol name=slice_prim])" + "136([Symbol name=linear])" + "4([Symbol name=unpack_trivial])" -> "136([Symbol name=linear])" + "135([Symbol name=convert_element_type])" -> "136([Symbol name=linear])" + "180([Symbol name=linear])" + "179([Symbol name=reshape])" -> "180([Symbol name=linear])" + "5([Symbol name=unpack_trivial])" -> "180([Symbol name=linear])" + "200([Symbol name=linear])" + "6([Symbol name=unpack_trivial])" -> "200([Symbol name=linear])" + "199([Symbol name=convert_element_type])" -> "200([Symbol name=linear])" + "201([Symbol name=linear])" + "7([Symbol name=unpack_trivial])" -> "201([Symbol name=linear])" + "199([Symbol name=convert_element_type])" -> "201([Symbol name=linear])" + "216([Symbol name=linear])" + "8([Symbol name=unpack_trivial])" -> "216([Symbol name=linear])" + "215([Symbol name=convert_element_type])" -> "216([Symbol name=linear])" + "131([Symbol name=broadcast_in_dim])" + "9([Symbol name=unpack_trivial])" -> "131([Symbol name=broadcast_in_dim])" + "195([Symbol name=broadcast_in_dim])" + "10([Symbol name=unpack_trivial])" -> "195([Symbol name=broadcast_in_dim])" + "236([Symbol name=linear])" + "11([Symbol name=unpack_trivial])" -> "236([Symbol name=linear])" + "235([Symbol name=convert_element_type])" -> "236([Symbol name=linear])" + "280([Symbol name=linear])" + "12([Symbol name=unpack_trivial])" -> "280([Symbol name=linear])" + "279([Symbol name=reshape])" -> "280([Symbol name=linear])" + "300([Symbol name=linear])" + "299([Symbol name=convert_element_type])" -> "300([Symbol name=linear])" + "13([Symbol name=unpack_trivial])" -> "300([Symbol name=linear])" + "301([Symbol name=linear])" + "299([Symbol name=convert_element_type])" -> "301([Symbol name=linear])" + "14([Symbol name=unpack_trivial])" -> "301([Symbol name=linear])" + "316([Symbol name=linear])" + "315([Symbol name=convert_element_type])" -> "316([Symbol name=linear])" + "15([Symbol name=unpack_trivial])" -> "316([Symbol name=linear])" + "231([Symbol name=broadcast_in_dim])" + "16([Symbol name=unpack_trivial])" -> "231([Symbol name=broadcast_in_dim])" + "295([Symbol name=broadcast_in_dim])" + "17([Symbol name=unpack_trivial])" -> "295([Symbol name=broadcast_in_dim])" + "336([Symbol name=linear])" + "18([Symbol name=unpack_trivial])" -> "336([Symbol name=linear])" + "335([Symbol name=convert_element_type])" -> "336([Symbol name=linear])" + "380([Symbol name=linear])" + "19([Symbol name=unpack_trivial])" -> "380([Symbol name=linear])" + "379([Symbol name=reshape])" -> "380([Symbol name=linear])" + "400([Symbol name=linear])" + "20([Symbol name=unpack_trivial])" -> "400([Symbol name=linear])" + "399([Symbol name=convert_element_type])" -> "400([Symbol name=linear])" + "401([Symbol name=linear])" + "21([Symbol name=unpack_trivial])" -> "401([Symbol name=linear])" + "399([Symbol name=convert_element_type])" -> "401([Symbol name=linear])" + "416([Symbol name=linear])" + "22([Symbol name=unpack_trivial])" -> "416([Symbol name=linear])" + "415([Symbol name=convert_element_type])" -> "416([Symbol name=linear])" + "331([Symbol name=broadcast_in_dim])" + "23([Symbol name=unpack_trivial])" -> "331([Symbol name=broadcast_in_dim])" + "395([Symbol name=broadcast_in_dim])" + "24([Symbol name=unpack_trivial])" -> "395([Symbol name=broadcast_in_dim])" + "436([Symbol name=linear])" + "25([Symbol name=unpack_trivial])" -> "436([Symbol name=linear])" + "435([Symbol name=convert_element_type])" -> "436([Symbol name=linear])" + "480([Symbol name=linear])" + "26([Symbol name=unpack_trivial])" -> "480([Symbol name=linear])" + "479([Symbol name=reshape])" -> "480([Symbol name=linear])" + "500([Symbol name=linear])" + "27([Symbol name=unpack_trivial])" -> "500([Symbol name=linear])" + "499([Symbol name=convert_element_type])" -> "500([Symbol name=linear])" + "501([Symbol name=linear])" + "499([Symbol name=convert_element_type])" -> "501([Symbol name=linear])" + "28([Symbol name=unpack_trivial])" -> "501([Symbol name=linear])" + "516([Symbol name=linear])" + "515([Symbol name=convert_element_type])" -> "516([Symbol name=linear])" + "29([Symbol name=unpack_trivial])" -> "516([Symbol name=linear])" + "431([Symbol name=broadcast_in_dim])" + "30([Symbol name=unpack_trivial])" -> "431([Symbol name=broadcast_in_dim])" + "495([Symbol name=broadcast_in_dim])" + "31([Symbol name=unpack_trivial])" -> "495([Symbol name=broadcast_in_dim])" + "536([Symbol name=linear])" + "32([Symbol name=unpack_trivial])" -> "536([Symbol name=linear])" + "535([Symbol name=convert_element_type])" -> "536([Symbol name=linear])" + "580([Symbol name=linear])" + "33([Symbol name=unpack_trivial])" -> "580([Symbol name=linear])" + "579([Symbol name=reshape])" -> "580([Symbol name=linear])" + "600([Symbol name=linear])" + "34([Symbol name=unpack_trivial])" -> "600([Symbol name=linear])" + "599([Symbol name=convert_element_type])" -> "600([Symbol name=linear])" + "601([Symbol name=linear])" + "35([Symbol name=unpack_trivial])" -> "601([Symbol name=linear])" + "599([Symbol name=convert_element_type])" -> "601([Symbol name=linear])" + "616([Symbol name=linear])" + "36([Symbol name=unpack_trivial])" -> "616([Symbol name=linear])" + "615([Symbol name=convert_element_type])" -> "616([Symbol name=linear])" + "531([Symbol name=broadcast_in_dim])" + "37([Symbol name=unpack_trivial])" -> "531([Symbol name=broadcast_in_dim])" + "595([Symbol name=broadcast_in_dim])" + "38([Symbol name=unpack_trivial])" -> "595([Symbol name=broadcast_in_dim])" + "636([Symbol name=linear])" + "635([Symbol name=convert_element_type])" -> "636([Symbol name=linear])" + "39([Symbol name=unpack_trivial])" -> "636([Symbol name=linear])" + "680([Symbol name=linear])" + "40([Symbol name=unpack_trivial])" -> "680([Symbol name=linear])" + "679([Symbol name=reshape])" -> "680([Symbol name=linear])" + "700([Symbol name=linear])" + "41([Symbol name=unpack_trivial])" -> "700([Symbol name=linear])" + "699([Symbol name=convert_element_type])" -> "700([Symbol name=linear])" + "701([Symbol name=linear])" + "42([Symbol name=unpack_trivial])" -> "701([Symbol name=linear])" + "699([Symbol name=convert_element_type])" -> "701([Symbol name=linear])" + "716([Symbol name=linear])" + "43([Symbol name=unpack_trivial])" -> "716([Symbol name=linear])" + "715([Symbol name=convert_element_type])" -> "716([Symbol name=linear])" + "631([Symbol name=broadcast_in_dim])" + "44([Symbol name=unpack_trivial])" -> "631([Symbol name=broadcast_in_dim])" + "695([Symbol name=broadcast_in_dim])" + "45([Symbol name=unpack_trivial])" -> "695([Symbol name=broadcast_in_dim])" + "736([Symbol name=linear])" + "46([Symbol name=unpack_trivial])" -> "736([Symbol name=linear])" + "735([Symbol name=convert_element_type])" -> "736([Symbol name=linear])" + "780([Symbol name=linear])" + "779([Symbol name=reshape])" -> "780([Symbol name=linear])" + "47([Symbol name=unpack_trivial])" -> "780([Symbol name=linear])" + "800([Symbol name=linear])" + "48([Symbol name=unpack_trivial])" -> "800([Symbol name=linear])" + "799([Symbol name=convert_element_type])" -> "800([Symbol name=linear])" + "801([Symbol name=linear])" + "49([Symbol name=unpack_trivial])" -> "801([Symbol name=linear])" + "799([Symbol name=convert_element_type])" -> "801([Symbol name=linear])" + "816([Symbol name=linear])" + "50([Symbol name=unpack_trivial])" -> "816([Symbol name=linear])" + "815([Symbol name=convert_element_type])" -> "816([Symbol name=linear])" + "731([Symbol name=broadcast_in_dim])" + "51([Symbol name=unpack_trivial])" -> "731([Symbol name=broadcast_in_dim])" + "795([Symbol name=broadcast_in_dim])" + "52([Symbol name=unpack_trivial])" -> "795([Symbol name=broadcast_in_dim])" + "836([Symbol name=linear])" + "835([Symbol name=convert_element_type])" -> "836([Symbol name=linear])" + "53([Symbol name=unpack_trivial])" -> "836([Symbol name=linear])" + "880([Symbol name=linear])" + "54([Symbol name=unpack_trivial])" -> "880([Symbol name=linear])" + "879([Symbol name=reshape])" -> "880([Symbol name=linear])" + "900([Symbol name=linear])" + "899([Symbol name=convert_element_type])" -> "900([Symbol name=linear])" + "55([Symbol name=unpack_trivial])" -> "900([Symbol name=linear])" + "901([Symbol name=linear])" + "56([Symbol name=unpack_trivial])" -> "901([Symbol name=linear])" + "899([Symbol name=convert_element_type])" -> "901([Symbol name=linear])" + "916([Symbol name=linear])" + "57([Symbol name=unpack_trivial])" -> "916([Symbol name=linear])" + "915([Symbol name=convert_element_type])" -> "916([Symbol name=linear])" + "831([Symbol name=broadcast_in_dim])" + "58([Symbol name=unpack_trivial])" -> "831([Symbol name=broadcast_in_dim])" + "895([Symbol name=broadcast_in_dim])" + "59([Symbol name=unpack_trivial])" -> "895([Symbol name=broadcast_in_dim])" + "936([Symbol name=linear])" + "60([Symbol name=unpack_trivial])" -> "936([Symbol name=linear])" + "935([Symbol name=convert_element_type])" -> "936([Symbol name=linear])" + "980([Symbol name=linear])" + "979([Symbol name=reshape])" -> "980([Symbol name=linear])" + "61([Symbol name=unpack_trivial])" -> "980([Symbol name=linear])" + "1000([Symbol name=linear])" + "62([Symbol name=unpack_trivial])" -> "1000([Symbol name=linear])" + "999([Symbol name=convert_element_type])" -> "1000([Symbol name=linear])" + "1001([Symbol name=linear])" + "63([Symbol name=unpack_trivial])" -> "1001([Symbol name=linear])" + "999([Symbol name=convert_element_type])" -> "1001([Symbol name=linear])" + "1016([Symbol name=linear])" + "64([Symbol name=unpack_trivial])" -> "1016([Symbol name=linear])" + "1015([Symbol name=convert_element_type])" -> "1016([Symbol name=linear])" + "931([Symbol name=broadcast_in_dim])" + "65([Symbol name=unpack_trivial])" -> "931([Symbol name=broadcast_in_dim])" + "995([Symbol name=broadcast_in_dim])" + "66([Symbol name=unpack_trivial])" -> "995([Symbol name=broadcast_in_dim])" + "1036([Symbol name=linear])" + "67([Symbol name=unpack_trivial])" -> "1036([Symbol name=linear])" + "1035([Symbol name=convert_element_type])" -> "1036([Symbol name=linear])" + "1080([Symbol name=linear])" + "68([Symbol name=unpack_trivial])" -> "1080([Symbol name=linear])" + "1079([Symbol name=reshape])" -> "1080([Symbol name=linear])" + "1100([Symbol name=linear])" + "1099([Symbol name=convert_element_type])" -> "1100([Symbol name=linear])" + "69([Symbol name=unpack_trivial])" -> "1100([Symbol name=linear])" + "1101([Symbol name=linear])" + "1099([Symbol name=convert_element_type])" -> "1101([Symbol name=linear])" + "70([Symbol name=unpack_trivial])" -> "1101([Symbol name=linear])" + "1116([Symbol name=linear])" + "1115([Symbol name=convert_element_type])" -> "1116([Symbol name=linear])" + "71([Symbol name=unpack_trivial])" -> "1116([Symbol name=linear])" + "1031([Symbol name=broadcast_in_dim])" + "72([Symbol name=unpack_trivial])" -> "1031([Symbol name=broadcast_in_dim])" + "1095([Symbol name=broadcast_in_dim])" + "73([Symbol name=unpack_trivial])" -> "1095([Symbol name=broadcast_in_dim])" + "1136([Symbol name=linear])" + "74([Symbol name=unpack_trivial])" -> "1136([Symbol name=linear])" + "1135([Symbol name=convert_element_type])" -> "1136([Symbol name=linear])" + "1180([Symbol name=linear])" + "75([Symbol name=unpack_trivial])" -> "1180([Symbol name=linear])" + "1179([Symbol name=reshape])" -> "1180([Symbol name=linear])" + "1200([Symbol name=linear])" + "76([Symbol name=unpack_trivial])" -> "1200([Symbol name=linear])" + "1199([Symbol name=convert_element_type])" -> "1200([Symbol name=linear])" + "1201([Symbol name=linear])" + "77([Symbol name=unpack_trivial])" -> "1201([Symbol name=linear])" + "1199([Symbol name=convert_element_type])" -> "1201([Symbol name=linear])" + "1216([Symbol name=linear])" + "78([Symbol name=unpack_trivial])" -> "1216([Symbol name=linear])" + "1215([Symbol name=convert_element_type])" -> "1216([Symbol name=linear])" + "1131([Symbol name=broadcast_in_dim])" + "79([Symbol name=unpack_trivial])" -> "1131([Symbol name=broadcast_in_dim])" + "1195([Symbol name=broadcast_in_dim])" + "80([Symbol name=unpack_trivial])" -> "1195([Symbol name=broadcast_in_dim])" + "1236([Symbol name=linear])" + "81([Symbol name=unpack_trivial])" -> "1236([Symbol name=linear])" + "1235([Symbol name=convert_element_type])" -> "1236([Symbol name=linear])" + "1280([Symbol name=linear])" + "82([Symbol name=unpack_trivial])" -> "1280([Symbol name=linear])" + "1279([Symbol name=reshape])" -> "1280([Symbol name=linear])" + "1300([Symbol name=linear])" + "83([Symbol name=unpack_trivial])" -> "1300([Symbol name=linear])" + "1299([Symbol name=convert_element_type])" -> "1300([Symbol name=linear])" + "1301([Symbol name=linear])" + "1299([Symbol name=convert_element_type])" -> "1301([Symbol name=linear])" + "84([Symbol name=unpack_trivial])" -> "1301([Symbol name=linear])" + "1316([Symbol name=linear])" + "1315([Symbol name=convert_element_type])" -> "1316([Symbol name=linear])" + "85([Symbol name=unpack_trivial])" -> "1316([Symbol name=linear])" + "1231([Symbol name=broadcast_in_dim])" + "86([Symbol name=unpack_trivial])" -> "1231([Symbol name=broadcast_in_dim])" + "1295([Symbol name=broadcast_in_dim])" + "87([Symbol name=unpack_trivial])" -> "1295([Symbol name=broadcast_in_dim])" + "1336([Symbol name=linear])" + "88([Symbol name=unpack_trivial])" -> "1336([Symbol name=linear])" + "1335([Symbol name=convert_element_type])" -> "1336([Symbol name=linear])" + "1380([Symbol name=linear])" + "89([Symbol name=unpack_trivial])" -> "1380([Symbol name=linear])" + "1379([Symbol name=reshape])" -> "1380([Symbol name=linear])" + "1400([Symbol name=linear])" + "90([Symbol name=unpack_trivial])" -> "1400([Symbol name=linear])" + "1399([Symbol name=convert_element_type])" -> "1400([Symbol name=linear])" + "1401([Symbol name=linear])" + "91([Symbol name=unpack_trivial])" -> "1401([Symbol name=linear])" + "1399([Symbol name=convert_element_type])" -> "1401([Symbol name=linear])" + "1416([Symbol name=linear])" + "92([Symbol name=unpack_trivial])" -> "1416([Symbol name=linear])" + "1415([Symbol name=convert_element_type])" -> "1416([Symbol name=linear])" + "1331([Symbol name=broadcast_in_dim])" + "93([Symbol name=unpack_trivial])" -> "1331([Symbol name=broadcast_in_dim])" + "1395([Symbol name=broadcast_in_dim])" + "94([Symbol name=unpack_trivial])" -> "1395([Symbol name=broadcast_in_dim])" + "1436([Symbol name=linear])" + "1435([Symbol name=convert_element_type])" -> "1436([Symbol name=linear])" + "95([Symbol name=unpack_trivial])" -> "1436([Symbol name=linear])" + "1480([Symbol name=linear])" + "96([Symbol name=unpack_trivial])" -> "1480([Symbol name=linear])" + "1479([Symbol name=reshape])" -> "1480([Symbol name=linear])" + "1500([Symbol name=linear])" + "97([Symbol name=unpack_trivial])" -> "1500([Symbol name=linear])" + "1499([Symbol name=convert_element_type])" -> "1500([Symbol name=linear])" + "1501([Symbol name=linear])" + "98([Symbol name=unpack_trivial])" -> "1501([Symbol name=linear])" + "1499([Symbol name=convert_element_type])" -> "1501([Symbol name=linear])" + "1516([Symbol name=linear])" + "99([Symbol name=unpack_trivial])" -> "1516([Symbol name=linear])" + "1515([Symbol name=convert_element_type])" -> "1516([Symbol name=linear])" + "1431([Symbol name=broadcast_in_dim])" + "100([Symbol name=unpack_trivial])" -> "1431([Symbol name=broadcast_in_dim])" + "1495([Symbol name=broadcast_in_dim])" + "101([Symbol name=unpack_trivial])" -> "1495([Symbol name=broadcast_in_dim])" + "1536([Symbol name=linear])" + "102([Symbol name=unpack_trivial])" -> "1536([Symbol name=linear])" + "1535([Symbol name=convert_element_type])" -> "1536([Symbol name=linear])" + "1580([Symbol name=linear])" + "1579([Symbol name=reshape])" -> "1580([Symbol name=linear])" + "103([Symbol name=unpack_trivial])" -> "1580([Symbol name=linear])" + "1600([Symbol name=linear])" + "104([Symbol name=unpack_trivial])" -> "1600([Symbol name=linear])" + "1599([Symbol name=convert_element_type])" -> "1600([Symbol name=linear])" + "1601([Symbol name=linear])" + "105([Symbol name=unpack_trivial])" -> "1601([Symbol name=linear])" + "1599([Symbol name=convert_element_type])" -> "1601([Symbol name=linear])" + "1616([Symbol name=linear])" + "106([Symbol name=unpack_trivial])" -> "1616([Symbol name=linear])" + "1615([Symbol name=convert_element_type])" -> "1616([Symbol name=linear])" + "1531([Symbol name=broadcast_in_dim])" + "107([Symbol name=unpack_trivial])" -> "1531([Symbol name=broadcast_in_dim])" + "1595([Symbol name=broadcast_in_dim])" + "108([Symbol name=unpack_trivial])" -> "1595([Symbol name=broadcast_in_dim])" + "1636([Symbol name=linear])" + "1635([Symbol name=convert_element_type])" -> "1636([Symbol name=linear])" + "109([Symbol name=unpack_trivial])" -> "1636([Symbol name=linear])" + "1680([Symbol name=linear])" + "110([Symbol name=unpack_trivial])" -> "1680([Symbol name=linear])" + "1679([Symbol name=reshape])" -> "1680([Symbol name=linear])" + "1700([Symbol name=linear])" + "1699([Symbol name=convert_element_type])" -> "1700([Symbol name=linear])" + "111([Symbol name=unpack_trivial])" -> "1700([Symbol name=linear])" + "1701([Symbol name=linear])" + "112([Symbol name=unpack_trivial])" -> "1701([Symbol name=linear])" + "1699([Symbol name=convert_element_type])" -> "1701([Symbol name=linear])" + "1716([Symbol name=linear])" + "113([Symbol name=unpack_trivial])" -> "1716([Symbol name=linear])" + "1715([Symbol name=convert_element_type])" -> "1716([Symbol name=linear])" + "1631([Symbol name=broadcast_in_dim])" + "114([Symbol name=unpack_trivial])" -> "1631([Symbol name=broadcast_in_dim])" + "1695([Symbol name=broadcast_in_dim])" + "115([Symbol name=unpack_trivial])" -> "1695([Symbol name=broadcast_in_dim])" + "1731([Symbol name=broadcast_in_dim])" + "116([Symbol name=unpack_trivial])" -> "1731([Symbol name=broadcast_in_dim])" + "121([Symbol name=convert_element_type])" + "120([Symbol name=embedding])" -> "121([Symbol name=convert_element_type])" + "182([Symbol name=convert_element_type])" + "120([Symbol name=embedding])" -> "182([Symbol name=convert_element_type])" + "1665([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1665([Symbol name=broadcast_in_dim])" + "265([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "265([Symbol name=broadcast_in_dim])" + "650([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "650([Symbol name=broadcast_in_dim])" + "1165([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1165([Symbol name=broadcast_in_dim])" + "1550([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1550([Symbol name=broadcast_in_dim])" + "150([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "150([Symbol name=broadcast_in_dim])" + "665([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "665([Symbol name=broadcast_in_dim])" + "1050([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1050([Symbol name=broadcast_in_dim])" + "1565([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1565([Symbol name=broadcast_in_dim])" + "165([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "165([Symbol name=broadcast_in_dim])" + "550([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "550([Symbol name=broadcast_in_dim])" + "1065([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1065([Symbol name=broadcast_in_dim])" + "1450([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1450([Symbol name=broadcast_in_dim])" + "565([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "565([Symbol name=broadcast_in_dim])" + "950([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "950([Symbol name=broadcast_in_dim])" + "1465([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1465([Symbol name=broadcast_in_dim])" + "450([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "450([Symbol name=broadcast_in_dim])" + "965([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "965([Symbol name=broadcast_in_dim])" + "1350([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1350([Symbol name=broadcast_in_dim])" + "465([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "465([Symbol name=broadcast_in_dim])" + "850([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "850([Symbol name=broadcast_in_dim])" + "1365([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1365([Symbol name=broadcast_in_dim])" + "350([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "350([Symbol name=broadcast_in_dim])" + "865([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "865([Symbol name=broadcast_in_dim])" + "1250([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1250([Symbol name=broadcast_in_dim])" + "365([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "365([Symbol name=broadcast_in_dim])" + "750([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "750([Symbol name=broadcast_in_dim])" + "1265([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1265([Symbol name=broadcast_in_dim])" + "1650([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1650([Symbol name=broadcast_in_dim])" + "250([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "250([Symbol name=broadcast_in_dim])" + "765([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "765([Symbol name=broadcast_in_dim])" + "1150([Symbol name=broadcast_in_dim])" + "118([Symbol name=slice_prim])" -> "1150([Symbol name=broadcast_in_dim])" + "768([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "768([Symbol name=broadcast_in_dim])" + "1153([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1153([Symbol name=broadcast_in_dim])" + "1668([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1668([Symbol name=broadcast_in_dim])" + "268([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "268([Symbol name=broadcast_in_dim])" + "653([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "653([Symbol name=broadcast_in_dim])" + "1168([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1168([Symbol name=broadcast_in_dim])" + "1553([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1553([Symbol name=broadcast_in_dim])" + "153([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "153([Symbol name=broadcast_in_dim])" + "668([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "668([Symbol name=broadcast_in_dim])" + "1053([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1053([Symbol name=broadcast_in_dim])" + "1568([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1568([Symbol name=broadcast_in_dim])" + "168([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "168([Symbol name=broadcast_in_dim])" + "553([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "553([Symbol name=broadcast_in_dim])" + "1068([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1068([Symbol name=broadcast_in_dim])" + "1453([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1453([Symbol name=broadcast_in_dim])" + "568([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "568([Symbol name=broadcast_in_dim])" + "953([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "953([Symbol name=broadcast_in_dim])" + "1468([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1468([Symbol name=broadcast_in_dim])" + "453([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "453([Symbol name=broadcast_in_dim])" + "968([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "968([Symbol name=broadcast_in_dim])" + "1353([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1353([Symbol name=broadcast_in_dim])" + "468([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "468([Symbol name=broadcast_in_dim])" + "853([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "853([Symbol name=broadcast_in_dim])" + "1368([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1368([Symbol name=broadcast_in_dim])" + "353([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "353([Symbol name=broadcast_in_dim])" + "868([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "868([Symbol name=broadcast_in_dim])" + "1253([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1253([Symbol name=broadcast_in_dim])" + "368([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "368([Symbol name=broadcast_in_dim])" + "753([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "753([Symbol name=broadcast_in_dim])" + "1268([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1268([Symbol name=broadcast_in_dim])" + "1653([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "1653([Symbol name=broadcast_in_dim])" + "253([Symbol name=broadcast_in_dim])" + "119([Symbol name=slice_prim])" -> "253([Symbol name=broadcast_in_dim])" + "137([Symbol name=reshape])" + "136([Symbol name=linear])" -> "137([Symbol name=reshape])" + "181([Symbol name=convert_element_type])" + "180([Symbol name=linear])" -> "181([Symbol name=convert_element_type])" + "208([Symbol name=convert_element_type])" + "200([Symbol name=linear])" -> "208([Symbol name=convert_element_type])" + "202([Symbol name=convert_element_type])" + "200([Symbol name=linear])" -> "202([Symbol name=convert_element_type])" + "213([Symbol name=convert_element_type])" + "201([Symbol name=linear])" -> "213([Symbol name=convert_element_type])" + "217([Symbol name=convert_element_type])" + "216([Symbol name=linear])" -> "217([Symbol name=convert_element_type])" + "133([Symbol name=convert_element_type])" + "131([Symbol name=broadcast_in_dim])" -> "133([Symbol name=convert_element_type])" + "197([Symbol name=convert_element_type])" + "195([Symbol name=broadcast_in_dim])" -> "197([Symbol name=convert_element_type])" + "237([Symbol name=reshape])" + "236([Symbol name=linear])" -> "237([Symbol name=reshape])" + "281([Symbol name=convert_element_type])" + "280([Symbol name=linear])" -> "281([Symbol name=convert_element_type])" + "308([Symbol name=convert_element_type])" + "300([Symbol name=linear])" -> "308([Symbol name=convert_element_type])" + "302([Symbol name=convert_element_type])" + "300([Symbol name=linear])" -> "302([Symbol name=convert_element_type])" + "313([Symbol name=convert_element_type])" + "301([Symbol name=linear])" -> "313([Symbol name=convert_element_type])" + "317([Symbol name=convert_element_type])" + "316([Symbol name=linear])" -> "317([Symbol name=convert_element_type])" + "233([Symbol name=convert_element_type])" + "231([Symbol name=broadcast_in_dim])" -> "233([Symbol name=convert_element_type])" + "297([Symbol name=convert_element_type])" + "295([Symbol name=broadcast_in_dim])" -> "297([Symbol name=convert_element_type])" + "337([Symbol name=reshape])" + "336([Symbol name=linear])" -> "337([Symbol name=reshape])" + "381([Symbol name=convert_element_type])" + "380([Symbol name=linear])" -> "381([Symbol name=convert_element_type])" + "408([Symbol name=convert_element_type])" + "400([Symbol name=linear])" -> "408([Symbol name=convert_element_type])" + "402([Symbol name=convert_element_type])" + "400([Symbol name=linear])" -> "402([Symbol name=convert_element_type])" + "413([Symbol name=convert_element_type])" + "401([Symbol name=linear])" -> "413([Symbol name=convert_element_type])" + "417([Symbol name=convert_element_type])" + "416([Symbol name=linear])" -> "417([Symbol name=convert_element_type])" + "333([Symbol name=convert_element_type])" + "331([Symbol name=broadcast_in_dim])" -> "333([Symbol name=convert_element_type])" + "397([Symbol name=convert_element_type])" + "395([Symbol name=broadcast_in_dim])" -> "397([Symbol name=convert_element_type])" + "437([Symbol name=reshape])" + "436([Symbol name=linear])" -> "437([Symbol name=reshape])" + "481([Symbol name=convert_element_type])" + "480([Symbol name=linear])" -> "481([Symbol name=convert_element_type])" + "508([Symbol name=convert_element_type])" + "500([Symbol name=linear])" -> "508([Symbol name=convert_element_type])" + "502([Symbol name=convert_element_type])" + "500([Symbol name=linear])" -> "502([Symbol name=convert_element_type])" + "513([Symbol name=convert_element_type])" + "501([Symbol name=linear])" -> "513([Symbol name=convert_element_type])" + "517([Symbol name=convert_element_type])" + "516([Symbol name=linear])" -> "517([Symbol name=convert_element_type])" + "433([Symbol name=convert_element_type])" + "431([Symbol name=broadcast_in_dim])" -> "433([Symbol name=convert_element_type])" + "497([Symbol name=convert_element_type])" + "495([Symbol name=broadcast_in_dim])" -> "497([Symbol name=convert_element_type])" + "537([Symbol name=reshape])" + "536([Symbol name=linear])" -> "537([Symbol name=reshape])" + "581([Symbol name=convert_element_type])" + "580([Symbol name=linear])" -> "581([Symbol name=convert_element_type])" + "608([Symbol name=convert_element_type])" + "600([Symbol name=linear])" -> "608([Symbol name=convert_element_type])" + "602([Symbol name=convert_element_type])" + "600([Symbol name=linear])" -> "602([Symbol name=convert_element_type])" + "613([Symbol name=convert_element_type])" + "601([Symbol name=linear])" -> "613([Symbol name=convert_element_type])" + "617([Symbol name=convert_element_type])" + "616([Symbol name=linear])" -> "617([Symbol name=convert_element_type])" + "533([Symbol name=convert_element_type])" + "531([Symbol name=broadcast_in_dim])" -> "533([Symbol name=convert_element_type])" + "597([Symbol name=convert_element_type])" + "595([Symbol name=broadcast_in_dim])" -> "597([Symbol name=convert_element_type])" + "637([Symbol name=reshape])" + "636([Symbol name=linear])" -> "637([Symbol name=reshape])" + "681([Symbol name=convert_element_type])" + "680([Symbol name=linear])" -> "681([Symbol name=convert_element_type])" + "708([Symbol name=convert_element_type])" + "700([Symbol name=linear])" -> "708([Symbol name=convert_element_type])" + "702([Symbol name=convert_element_type])" + "700([Symbol name=linear])" -> "702([Symbol name=convert_element_type])" + "713([Symbol name=convert_element_type])" + "701([Symbol name=linear])" -> "713([Symbol name=convert_element_type])" + "717([Symbol name=convert_element_type])" + "716([Symbol name=linear])" -> "717([Symbol name=convert_element_type])" + "633([Symbol name=convert_element_type])" + "631([Symbol name=broadcast_in_dim])" -> "633([Symbol name=convert_element_type])" + "697([Symbol name=convert_element_type])" + "695([Symbol name=broadcast_in_dim])" -> "697([Symbol name=convert_element_type])" + "737([Symbol name=reshape])" + "736([Symbol name=linear])" -> "737([Symbol name=reshape])" + "781([Symbol name=convert_element_type])" + "780([Symbol name=linear])" -> "781([Symbol name=convert_element_type])" + "808([Symbol name=convert_element_type])" + "800([Symbol name=linear])" -> "808([Symbol name=convert_element_type])" + "802([Symbol name=convert_element_type])" + "800([Symbol name=linear])" -> "802([Symbol name=convert_element_type])" + "813([Symbol name=convert_element_type])" + "801([Symbol name=linear])" -> "813([Symbol name=convert_element_type])" + "817([Symbol name=convert_element_type])" + "816([Symbol name=linear])" -> "817([Symbol name=convert_element_type])" + "733([Symbol name=convert_element_type])" + "731([Symbol name=broadcast_in_dim])" -> "733([Symbol name=convert_element_type])" + "797([Symbol name=convert_element_type])" + "795([Symbol name=broadcast_in_dim])" -> "797([Symbol name=convert_element_type])" + "837([Symbol name=reshape])" + "836([Symbol name=linear])" -> "837([Symbol name=reshape])" + "881([Symbol name=convert_element_type])" + "880([Symbol name=linear])" -> "881([Symbol name=convert_element_type])" + "908([Symbol name=convert_element_type])" + "900([Symbol name=linear])" -> "908([Symbol name=convert_element_type])" + "902([Symbol name=convert_element_type])" + "900([Symbol name=linear])" -> "902([Symbol name=convert_element_type])" + "913([Symbol name=convert_element_type])" + "901([Symbol name=linear])" -> "913([Symbol name=convert_element_type])" + "917([Symbol name=convert_element_type])" + "916([Symbol name=linear])" -> "917([Symbol name=convert_element_type])" + "833([Symbol name=convert_element_type])" + "831([Symbol name=broadcast_in_dim])" -> "833([Symbol name=convert_element_type])" + "897([Symbol name=convert_element_type])" + "895([Symbol name=broadcast_in_dim])" -> "897([Symbol name=convert_element_type])" + "937([Symbol name=reshape])" + "936([Symbol name=linear])" -> "937([Symbol name=reshape])" + "981([Symbol name=convert_element_type])" + "980([Symbol name=linear])" -> "981([Symbol name=convert_element_type])" + "1008([Symbol name=convert_element_type])" + "1000([Symbol name=linear])" -> "1008([Symbol name=convert_element_type])" + "1002([Symbol name=convert_element_type])" + "1000([Symbol name=linear])" -> "1002([Symbol name=convert_element_type])" + "1013([Symbol name=convert_element_type])" + "1001([Symbol name=linear])" -> "1013([Symbol name=convert_element_type])" + "1017([Symbol name=convert_element_type])" + "1016([Symbol name=linear])" -> "1017([Symbol name=convert_element_type])" + "933([Symbol name=convert_element_type])" + "931([Symbol name=broadcast_in_dim])" -> "933([Symbol name=convert_element_type])" + "997([Symbol name=convert_element_type])" + "995([Symbol name=broadcast_in_dim])" -> "997([Symbol name=convert_element_type])" + "1037([Symbol name=reshape])" + "1036([Symbol name=linear])" -> "1037([Symbol name=reshape])" + "1081([Symbol name=convert_element_type])" + "1080([Symbol name=linear])" -> "1081([Symbol name=convert_element_type])" + "1108([Symbol name=convert_element_type])" + "1100([Symbol name=linear])" -> "1108([Symbol name=convert_element_type])" + "1102([Symbol name=convert_element_type])" + "1100([Symbol name=linear])" -> "1102([Symbol name=convert_element_type])" + "1113([Symbol name=convert_element_type])" + "1101([Symbol name=linear])" -> "1113([Symbol name=convert_element_type])" + "1117([Symbol name=convert_element_type])" + "1116([Symbol name=linear])" -> "1117([Symbol name=convert_element_type])" + "1033([Symbol name=convert_element_type])" + "1031([Symbol name=broadcast_in_dim])" -> "1033([Symbol name=convert_element_type])" + "1097([Symbol name=convert_element_type])" + "1095([Symbol name=broadcast_in_dim])" -> "1097([Symbol name=convert_element_type])" + "1137([Symbol name=reshape])" + "1136([Symbol name=linear])" -> "1137([Symbol name=reshape])" + "1181([Symbol name=convert_element_type])" + "1180([Symbol name=linear])" -> "1181([Symbol name=convert_element_type])" + "1208([Symbol name=convert_element_type])" + "1200([Symbol name=linear])" -> "1208([Symbol name=convert_element_type])" + "1202([Symbol name=convert_element_type])" + "1200([Symbol name=linear])" -> "1202([Symbol name=convert_element_type])" + "1213([Symbol name=convert_element_type])" + "1201([Symbol name=linear])" -> "1213([Symbol name=convert_element_type])" + "1217([Symbol name=convert_element_type])" + "1216([Symbol name=linear])" -> "1217([Symbol name=convert_element_type])" + "1133([Symbol name=convert_element_type])" + "1131([Symbol name=broadcast_in_dim])" -> "1133([Symbol name=convert_element_type])" + "1197([Symbol name=convert_element_type])" + "1195([Symbol name=broadcast_in_dim])" -> "1197([Symbol name=convert_element_type])" + "1237([Symbol name=reshape])" + "1236([Symbol name=linear])" -> "1237([Symbol name=reshape])" + "1281([Symbol name=convert_element_type])" + "1280([Symbol name=linear])" -> "1281([Symbol name=convert_element_type])" + "1308([Symbol name=convert_element_type])" + "1300([Symbol name=linear])" -> "1308([Symbol name=convert_element_type])" + "1302([Symbol name=convert_element_type])" + "1300([Symbol name=linear])" -> "1302([Symbol name=convert_element_type])" + "1313([Symbol name=convert_element_type])" + "1301([Symbol name=linear])" -> "1313([Symbol name=convert_element_type])" + "1317([Symbol name=convert_element_type])" + "1316([Symbol name=linear])" -> "1317([Symbol name=convert_element_type])" + "1233([Symbol name=convert_element_type])" + "1231([Symbol name=broadcast_in_dim])" -> "1233([Symbol name=convert_element_type])" + "1297([Symbol name=convert_element_type])" + "1295([Symbol name=broadcast_in_dim])" -> "1297([Symbol name=convert_element_type])" + "1337([Symbol name=reshape])" + "1336([Symbol name=linear])" -> "1337([Symbol name=reshape])" + "1381([Symbol name=convert_element_type])" + "1380([Symbol name=linear])" -> "1381([Symbol name=convert_element_type])" + "1408([Symbol name=convert_element_type])" + "1400([Symbol name=linear])" -> "1408([Symbol name=convert_element_type])" + "1402([Symbol name=convert_element_type])" + "1400([Symbol name=linear])" -> "1402([Symbol name=convert_element_type])" + "1413([Symbol name=convert_element_type])" + "1401([Symbol name=linear])" -> "1413([Symbol name=convert_element_type])" + "1417([Symbol name=convert_element_type])" + "1416([Symbol name=linear])" -> "1417([Symbol name=convert_element_type])" + "1333([Symbol name=convert_element_type])" + "1331([Symbol name=broadcast_in_dim])" -> "1333([Symbol name=convert_element_type])" + "1397([Symbol name=convert_element_type])" + "1395([Symbol name=broadcast_in_dim])" -> "1397([Symbol name=convert_element_type])" + "1437([Symbol name=reshape])" + "1436([Symbol name=linear])" -> "1437([Symbol name=reshape])" + "1481([Symbol name=convert_element_type])" + "1480([Symbol name=linear])" -> "1481([Symbol name=convert_element_type])" + "1508([Symbol name=convert_element_type])" + "1500([Symbol name=linear])" -> "1508([Symbol name=convert_element_type])" + "1502([Symbol name=convert_element_type])" + "1500([Symbol name=linear])" -> "1502([Symbol name=convert_element_type])" + "1513([Symbol name=convert_element_type])" + "1501([Symbol name=linear])" -> "1513([Symbol name=convert_element_type])" + "1517([Symbol name=convert_element_type])" + "1516([Symbol name=linear])" -> "1517([Symbol name=convert_element_type])" + "1433([Symbol name=convert_element_type])" + "1431([Symbol name=broadcast_in_dim])" -> "1433([Symbol name=convert_element_type])" + "1497([Symbol name=convert_element_type])" + "1495([Symbol name=broadcast_in_dim])" -> "1497([Symbol name=convert_element_type])" + "1537([Symbol name=reshape])" + "1536([Symbol name=linear])" -> "1537([Symbol name=reshape])" + "1581([Symbol name=convert_element_type])" + "1580([Symbol name=linear])" -> "1581([Symbol name=convert_element_type])" + "1608([Symbol name=convert_element_type])" + "1600([Symbol name=linear])" -> "1608([Symbol name=convert_element_type])" + "1602([Symbol name=convert_element_type])" + "1600([Symbol name=linear])" -> "1602([Symbol name=convert_element_type])" + "1613([Symbol name=convert_element_type])" + "1601([Symbol name=linear])" -> "1613([Symbol name=convert_element_type])" + "1617([Symbol name=convert_element_type])" + "1616([Symbol name=linear])" -> "1617([Symbol name=convert_element_type])" + "1533([Symbol name=convert_element_type])" + "1531([Symbol name=broadcast_in_dim])" -> "1533([Symbol name=convert_element_type])" + "1597([Symbol name=convert_element_type])" + "1595([Symbol name=broadcast_in_dim])" -> "1597([Symbol name=convert_element_type])" + "1637([Symbol name=reshape])" + "1636([Symbol name=linear])" -> "1637([Symbol name=reshape])" + "1681([Symbol name=convert_element_type])" + "1680([Symbol name=linear])" -> "1681([Symbol name=convert_element_type])" + "1708([Symbol name=convert_element_type])" + "1700([Symbol name=linear])" -> "1708([Symbol name=convert_element_type])" + "1702([Symbol name=convert_element_type])" + "1700([Symbol name=linear])" -> "1702([Symbol name=convert_element_type])" + "1713([Symbol name=convert_element_type])" + "1701([Symbol name=linear])" -> "1713([Symbol name=convert_element_type])" + "1717([Symbol name=convert_element_type])" + "1716([Symbol name=linear])" -> "1717([Symbol name=convert_element_type])" + "1633([Symbol name=convert_element_type])" + "1631([Symbol name=broadcast_in_dim])" -> "1633([Symbol name=convert_element_type])" + "1697([Symbol name=convert_element_type])" + "1695([Symbol name=broadcast_in_dim])" -> "1697([Symbol name=convert_element_type])" + "1733([Symbol name=convert_element_type])" + "1731([Symbol name=broadcast_in_dim])" -> "1733([Symbol name=convert_element_type])" + "129([Symbol name=mul])" + "128([Symbol name=broadcast_in_dim])" -> "129([Symbol name=mul])" + "121([Symbol name=convert_element_type])" -> "129([Symbol name=mul])" + "122([Symbol name=mul])" + "121([Symbol name=convert_element_type])" -> "122([Symbol name=mul])" + "183([Symbol name=add])" + "181([Symbol name=convert_element_type])" -> "183([Symbol name=add])" + "182([Symbol name=convert_element_type])" -> "183([Symbol name=add])" + "1667([Symbol name=mul])" + "1665([Symbol name=broadcast_in_dim])" -> "1667([Symbol name=mul])" + "1666([Symbol name=convert_element_type])" -> "1667([Symbol name=mul])" + "267([Symbol name=mul])" + "265([Symbol name=broadcast_in_dim])" -> "267([Symbol name=mul])" + "266([Symbol name=convert_element_type])" -> "267([Symbol name=mul])" + "652([Symbol name=mul])" + "650([Symbol name=broadcast_in_dim])" -> "652([Symbol name=mul])" + "651([Symbol name=convert_element_type])" -> "652([Symbol name=mul])" + "1167([Symbol name=mul])" + "1165([Symbol name=broadcast_in_dim])" -> "1167([Symbol name=mul])" + "1166([Symbol name=convert_element_type])" -> "1167([Symbol name=mul])" + "1552([Symbol name=mul])" + "1550([Symbol name=broadcast_in_dim])" -> "1552([Symbol name=mul])" + "1551([Symbol name=convert_element_type])" -> "1552([Symbol name=mul])" + "152([Symbol name=mul])" + "150([Symbol name=broadcast_in_dim])" -> "152([Symbol name=mul])" + "151([Symbol name=convert_element_type])" -> "152([Symbol name=mul])" + "667([Symbol name=mul])" + "665([Symbol name=broadcast_in_dim])" -> "667([Symbol name=mul])" + "666([Symbol name=convert_element_type])" -> "667([Symbol name=mul])" + "1052([Symbol name=mul])" + "1050([Symbol name=broadcast_in_dim])" -> "1052([Symbol name=mul])" + "1051([Symbol name=convert_element_type])" -> "1052([Symbol name=mul])" + "1567([Symbol name=mul])" + "1565([Symbol name=broadcast_in_dim])" -> "1567([Symbol name=mul])" + "1566([Symbol name=convert_element_type])" -> "1567([Symbol name=mul])" + "167([Symbol name=mul])" + "165([Symbol name=broadcast_in_dim])" -> "167([Symbol name=mul])" + "166([Symbol name=convert_element_type])" -> "167([Symbol name=mul])" + "552([Symbol name=mul])" + "550([Symbol name=broadcast_in_dim])" -> "552([Symbol name=mul])" + "551([Symbol name=convert_element_type])" -> "552([Symbol name=mul])" + "1067([Symbol name=mul])" + "1065([Symbol name=broadcast_in_dim])" -> "1067([Symbol name=mul])" + "1066([Symbol name=convert_element_type])" -> "1067([Symbol name=mul])" + "1452([Symbol name=mul])" + "1450([Symbol name=broadcast_in_dim])" -> "1452([Symbol name=mul])" + "1451([Symbol name=convert_element_type])" -> "1452([Symbol name=mul])" + "567([Symbol name=mul])" + "565([Symbol name=broadcast_in_dim])" -> "567([Symbol name=mul])" + "566([Symbol name=convert_element_type])" -> "567([Symbol name=mul])" + "952([Symbol name=mul])" + "950([Symbol name=broadcast_in_dim])" -> "952([Symbol name=mul])" + "951([Symbol name=convert_element_type])" -> "952([Symbol name=mul])" + "1467([Symbol name=mul])" + "1465([Symbol name=broadcast_in_dim])" -> "1467([Symbol name=mul])" + "1466([Symbol name=convert_element_type])" -> "1467([Symbol name=mul])" + "452([Symbol name=mul])" + "450([Symbol name=broadcast_in_dim])" -> "452([Symbol name=mul])" + "451([Symbol name=convert_element_type])" -> "452([Symbol name=mul])" + "967([Symbol name=mul])" + "965([Symbol name=broadcast_in_dim])" -> "967([Symbol name=mul])" + "966([Symbol name=convert_element_type])" -> "967([Symbol name=mul])" + "1352([Symbol name=mul])" + "1350([Symbol name=broadcast_in_dim])" -> "1352([Symbol name=mul])" + "1351([Symbol name=convert_element_type])" -> "1352([Symbol name=mul])" + "467([Symbol name=mul])" + "465([Symbol name=broadcast_in_dim])" -> "467([Symbol name=mul])" + "466([Symbol name=convert_element_type])" -> "467([Symbol name=mul])" + "852([Symbol name=mul])" + "850([Symbol name=broadcast_in_dim])" -> "852([Symbol name=mul])" + "851([Symbol name=convert_element_type])" -> "852([Symbol name=mul])" + "1367([Symbol name=mul])" + "1365([Symbol name=broadcast_in_dim])" -> "1367([Symbol name=mul])" + "1366([Symbol name=convert_element_type])" -> "1367([Symbol name=mul])" + "352([Symbol name=mul])" + "350([Symbol name=broadcast_in_dim])" -> "352([Symbol name=mul])" + "351([Symbol name=convert_element_type])" -> "352([Symbol name=mul])" + "867([Symbol name=mul])" + "865([Symbol name=broadcast_in_dim])" -> "867([Symbol name=mul])" + "866([Symbol name=convert_element_type])" -> "867([Symbol name=mul])" + "1252([Symbol name=mul])" + "1250([Symbol name=broadcast_in_dim])" -> "1252([Symbol name=mul])" + "1251([Symbol name=convert_element_type])" -> "1252([Symbol name=mul])" + "367([Symbol name=mul])" + "365([Symbol name=broadcast_in_dim])" -> "367([Symbol name=mul])" + "366([Symbol name=convert_element_type])" -> "367([Symbol name=mul])" + "752([Symbol name=mul])" + "750([Symbol name=broadcast_in_dim])" -> "752([Symbol name=mul])" + "751([Symbol name=convert_element_type])" -> "752([Symbol name=mul])" + "1267([Symbol name=mul])" + "1265([Symbol name=broadcast_in_dim])" -> "1267([Symbol name=mul])" + "1266([Symbol name=convert_element_type])" -> "1267([Symbol name=mul])" + "1652([Symbol name=mul])" + "1650([Symbol name=broadcast_in_dim])" -> "1652([Symbol name=mul])" + "1651([Symbol name=convert_element_type])" -> "1652([Symbol name=mul])" + "252([Symbol name=mul])" + "250([Symbol name=broadcast_in_dim])" -> "252([Symbol name=mul])" + "251([Symbol name=convert_element_type])" -> "252([Symbol name=mul])" + "767([Symbol name=mul])" + "765([Symbol name=broadcast_in_dim])" -> "767([Symbol name=mul])" + "766([Symbol name=convert_element_type])" -> "767([Symbol name=mul])" + "1152([Symbol name=mul])" + "1150([Symbol name=broadcast_in_dim])" -> "1152([Symbol name=mul])" + "1151([Symbol name=convert_element_type])" -> "1152([Symbol name=mul])" + "770([Symbol name=mul])" + "768([Symbol name=broadcast_in_dim])" -> "770([Symbol name=mul])" + "769([Symbol name=convert_element_type])" -> "770([Symbol name=mul])" + "1155([Symbol name=mul])" + "1153([Symbol name=broadcast_in_dim])" -> "1155([Symbol name=mul])" + "1154([Symbol name=convert_element_type])" -> "1155([Symbol name=mul])" + "1670([Symbol name=mul])" + "1668([Symbol name=broadcast_in_dim])" -> "1670([Symbol name=mul])" + "1669([Symbol name=convert_element_type])" -> "1670([Symbol name=mul])" + "270([Symbol name=mul])" + "268([Symbol name=broadcast_in_dim])" -> "270([Symbol name=mul])" + "269([Symbol name=convert_element_type])" -> "270([Symbol name=mul])" + "655([Symbol name=mul])" + "653([Symbol name=broadcast_in_dim])" -> "655([Symbol name=mul])" + "654([Symbol name=convert_element_type])" -> "655([Symbol name=mul])" + "1170([Symbol name=mul])" + "1168([Symbol name=broadcast_in_dim])" -> "1170([Symbol name=mul])" + "1169([Symbol name=convert_element_type])" -> "1170([Symbol name=mul])" + "1555([Symbol name=mul])" + "1553([Symbol name=broadcast_in_dim])" -> "1555([Symbol name=mul])" + "1554([Symbol name=convert_element_type])" -> "1555([Symbol name=mul])" + "155([Symbol name=mul])" + "153([Symbol name=broadcast_in_dim])" -> "155([Symbol name=mul])" + "154([Symbol name=convert_element_type])" -> "155([Symbol name=mul])" + "670([Symbol name=mul])" + "668([Symbol name=broadcast_in_dim])" -> "670([Symbol name=mul])" + "669([Symbol name=convert_element_type])" -> "670([Symbol name=mul])" + "1055([Symbol name=mul])" + "1053([Symbol name=broadcast_in_dim])" -> "1055([Symbol name=mul])" + "1054([Symbol name=convert_element_type])" -> "1055([Symbol name=mul])" + "1570([Symbol name=mul])" + "1568([Symbol name=broadcast_in_dim])" -> "1570([Symbol name=mul])" + "1569([Symbol name=convert_element_type])" -> "1570([Symbol name=mul])" + "170([Symbol name=mul])" + "168([Symbol name=broadcast_in_dim])" -> "170([Symbol name=mul])" + "169([Symbol name=convert_element_type])" -> "170([Symbol name=mul])" + "555([Symbol name=mul])" + "553([Symbol name=broadcast_in_dim])" -> "555([Symbol name=mul])" + "554([Symbol name=convert_element_type])" -> "555([Symbol name=mul])" + "1070([Symbol name=mul])" + "1068([Symbol name=broadcast_in_dim])" -> "1070([Symbol name=mul])" + "1069([Symbol name=convert_element_type])" -> "1070([Symbol name=mul])" + "1455([Symbol name=mul])" + "1453([Symbol name=broadcast_in_dim])" -> "1455([Symbol name=mul])" + "1454([Symbol name=convert_element_type])" -> "1455([Symbol name=mul])" + "570([Symbol name=mul])" + "568([Symbol name=broadcast_in_dim])" -> "570([Symbol name=mul])" + "569([Symbol name=convert_element_type])" -> "570([Symbol name=mul])" + "955([Symbol name=mul])" + "953([Symbol name=broadcast_in_dim])" -> "955([Symbol name=mul])" + "954([Symbol name=convert_element_type])" -> "955([Symbol name=mul])" + "1470([Symbol name=mul])" + "1468([Symbol name=broadcast_in_dim])" -> "1470([Symbol name=mul])" + "1469([Symbol name=convert_element_type])" -> "1470([Symbol name=mul])" + "455([Symbol name=mul])" + "453([Symbol name=broadcast_in_dim])" -> "455([Symbol name=mul])" + "454([Symbol name=convert_element_type])" -> "455([Symbol name=mul])" + "970([Symbol name=mul])" + "968([Symbol name=broadcast_in_dim])" -> "970([Symbol name=mul])" + "969([Symbol name=convert_element_type])" -> "970([Symbol name=mul])" + "1355([Symbol name=mul])" + "1353([Symbol name=broadcast_in_dim])" -> "1355([Symbol name=mul])" + "1354([Symbol name=convert_element_type])" -> "1355([Symbol name=mul])" + "470([Symbol name=mul])" + "468([Symbol name=broadcast_in_dim])" -> "470([Symbol name=mul])" + "469([Symbol name=convert_element_type])" -> "470([Symbol name=mul])" + "855([Symbol name=mul])" + "853([Symbol name=broadcast_in_dim])" -> "855([Symbol name=mul])" + "854([Symbol name=convert_element_type])" -> "855([Symbol name=mul])" + "1370([Symbol name=mul])" + "1368([Symbol name=broadcast_in_dim])" -> "1370([Symbol name=mul])" + "1369([Symbol name=convert_element_type])" -> "1370([Symbol name=mul])" + "355([Symbol name=mul])" + "353([Symbol name=broadcast_in_dim])" -> "355([Symbol name=mul])" + "354([Symbol name=convert_element_type])" -> "355([Symbol name=mul])" + "870([Symbol name=mul])" + "868([Symbol name=broadcast_in_dim])" -> "870([Symbol name=mul])" + "869([Symbol name=convert_element_type])" -> "870([Symbol name=mul])" + "1255([Symbol name=mul])" + "1253([Symbol name=broadcast_in_dim])" -> "1255([Symbol name=mul])" + "1254([Symbol name=convert_element_type])" -> "1255([Symbol name=mul])" + "370([Symbol name=mul])" + "368([Symbol name=broadcast_in_dim])" -> "370([Symbol name=mul])" + "369([Symbol name=convert_element_type])" -> "370([Symbol name=mul])" + "755([Symbol name=mul])" + "753([Symbol name=broadcast_in_dim])" -> "755([Symbol name=mul])" + "754([Symbol name=convert_element_type])" -> "755([Symbol name=mul])" + "1270([Symbol name=mul])" + "1268([Symbol name=broadcast_in_dim])" -> "1270([Symbol name=mul])" + "1269([Symbol name=convert_element_type])" -> "1270([Symbol name=mul])" + "1655([Symbol name=mul])" + "1653([Symbol name=broadcast_in_dim])" -> "1655([Symbol name=mul])" + "1654([Symbol name=convert_element_type])" -> "1655([Symbol name=mul])" + "255([Symbol name=mul])" + "253([Symbol name=broadcast_in_dim])" -> "255([Symbol name=mul])" + "254([Symbol name=convert_element_type])" -> "255([Symbol name=mul])" + "138([Symbol name=transpose])" + "137([Symbol name=reshape])" -> "138([Symbol name=transpose])" + "210([Symbol name=mul])" + "208([Symbol name=convert_element_type])" -> "210([Symbol name=mul])" + "209([Symbol name=convert_element_type])" -> "210([Symbol name=mul])" + "203([Symbol name=neg])" + "202([Symbol name=convert_element_type])" -> "203([Symbol name=neg])" + "214([Symbol name=mul])" + "212([Symbol name=convert_element_type])" -> "214([Symbol name=mul])" + "213([Symbol name=convert_element_type])" -> "214([Symbol name=mul])" + "219([Symbol name=add])" + "217([Symbol name=convert_element_type])" -> "219([Symbol name=add])" + "218([Symbol name=convert_element_type])" -> "219([Symbol name=add])" + "134([Symbol name=mul])" + "132([Symbol name=convert_element_type])" -> "134([Symbol name=mul])" + "133([Symbol name=convert_element_type])" -> "134([Symbol name=mul])" + "198([Symbol name=mul])" + "196([Symbol name=convert_element_type])" -> "198([Symbol name=mul])" + "197([Symbol name=convert_element_type])" -> "198([Symbol name=mul])" + "238([Symbol name=transpose])" + "237([Symbol name=reshape])" -> "238([Symbol name=transpose])" + "283([Symbol name=add])" + "281([Symbol name=convert_element_type])" -> "283([Symbol name=add])" + "282([Symbol name=convert_element_type])" -> "283([Symbol name=add])" + "310([Symbol name=mul])" + "308([Symbol name=convert_element_type])" -> "310([Symbol name=mul])" + "309([Symbol name=convert_element_type])" -> "310([Symbol name=mul])" + "303([Symbol name=neg])" + "302([Symbol name=convert_element_type])" -> "303([Symbol name=neg])" + "314([Symbol name=mul])" + "312([Symbol name=convert_element_type])" -> "314([Symbol name=mul])" + "313([Symbol name=convert_element_type])" -> "314([Symbol name=mul])" + "319([Symbol name=add])" + "317([Symbol name=convert_element_type])" -> "319([Symbol name=add])" + "318([Symbol name=convert_element_type])" -> "319([Symbol name=add])" + "234([Symbol name=mul])" + "232([Symbol name=convert_element_type])" -> "234([Symbol name=mul])" + "233([Symbol name=convert_element_type])" -> "234([Symbol name=mul])" + "298([Symbol name=mul])" + "296([Symbol name=convert_element_type])" -> "298([Symbol name=mul])" + "297([Symbol name=convert_element_type])" -> "298([Symbol name=mul])" + "338([Symbol name=transpose])" + "337([Symbol name=reshape])" -> "338([Symbol name=transpose])" + "383([Symbol name=add])" + "381([Symbol name=convert_element_type])" -> "383([Symbol name=add])" + "382([Symbol name=convert_element_type])" -> "383([Symbol name=add])" + "410([Symbol name=mul])" + "408([Symbol name=convert_element_type])" -> "410([Symbol name=mul])" + "409([Symbol name=convert_element_type])" -> "410([Symbol name=mul])" + "403([Symbol name=neg])" + "402([Symbol name=convert_element_type])" -> "403([Symbol name=neg])" + "414([Symbol name=mul])" + "412([Symbol name=convert_element_type])" -> "414([Symbol name=mul])" + "413([Symbol name=convert_element_type])" -> "414([Symbol name=mul])" + "419([Symbol name=add])" + "417([Symbol name=convert_element_type])" -> "419([Symbol name=add])" + "418([Symbol name=convert_element_type])" -> "419([Symbol name=add])" + "334([Symbol name=mul])" + "332([Symbol name=convert_element_type])" -> "334([Symbol name=mul])" + "333([Symbol name=convert_element_type])" -> "334([Symbol name=mul])" + "398([Symbol name=mul])" + "396([Symbol name=convert_element_type])" -> "398([Symbol name=mul])" + "397([Symbol name=convert_element_type])" -> "398([Symbol name=mul])" + "438([Symbol name=transpose])" + "437([Symbol name=reshape])" -> "438([Symbol name=transpose])" + "483([Symbol name=add])" + "481([Symbol name=convert_element_type])" -> "483([Symbol name=add])" + "482([Symbol name=convert_element_type])" -> "483([Symbol name=add])" + "510([Symbol name=mul])" + "508([Symbol name=convert_element_type])" -> "510([Symbol name=mul])" + "509([Symbol name=convert_element_type])" -> "510([Symbol name=mul])" + "503([Symbol name=neg])" + "502([Symbol name=convert_element_type])" -> "503([Symbol name=neg])" + "514([Symbol name=mul])" + "512([Symbol name=convert_element_type])" -> "514([Symbol name=mul])" + "513([Symbol name=convert_element_type])" -> "514([Symbol name=mul])" + "519([Symbol name=add])" + "517([Symbol name=convert_element_type])" -> "519([Symbol name=add])" + "518([Symbol name=convert_element_type])" -> "519([Symbol name=add])" + "434([Symbol name=mul])" + "432([Symbol name=convert_element_type])" -> "434([Symbol name=mul])" + "433([Symbol name=convert_element_type])" -> "434([Symbol name=mul])" + "498([Symbol name=mul])" + "496([Symbol name=convert_element_type])" -> "498([Symbol name=mul])" + "497([Symbol name=convert_element_type])" -> "498([Symbol name=mul])" + "538([Symbol name=transpose])" + "537([Symbol name=reshape])" -> "538([Symbol name=transpose])" + "583([Symbol name=add])" + "581([Symbol name=convert_element_type])" -> "583([Symbol name=add])" + "582([Symbol name=convert_element_type])" -> "583([Symbol name=add])" + "610([Symbol name=mul])" + "608([Symbol name=convert_element_type])" -> "610([Symbol name=mul])" + "609([Symbol name=convert_element_type])" -> "610([Symbol name=mul])" + "603([Symbol name=neg])" + "602([Symbol name=convert_element_type])" -> "603([Symbol name=neg])" + "614([Symbol name=mul])" + "612([Symbol name=convert_element_type])" -> "614([Symbol name=mul])" + "613([Symbol name=convert_element_type])" -> "614([Symbol name=mul])" + "619([Symbol name=add])" + "617([Symbol name=convert_element_type])" -> "619([Symbol name=add])" + "618([Symbol name=convert_element_type])" -> "619([Symbol name=add])" + "534([Symbol name=mul])" + "532([Symbol name=convert_element_type])" -> "534([Symbol name=mul])" + "533([Symbol name=convert_element_type])" -> "534([Symbol name=mul])" + "598([Symbol name=mul])" + "596([Symbol name=convert_element_type])" -> "598([Symbol name=mul])" + "597([Symbol name=convert_element_type])" -> "598([Symbol name=mul])" + "638([Symbol name=transpose])" + "637([Symbol name=reshape])" -> "638([Symbol name=transpose])" + "683([Symbol name=add])" + "681([Symbol name=convert_element_type])" -> "683([Symbol name=add])" + "682([Symbol name=convert_element_type])" -> "683([Symbol name=add])" + "710([Symbol name=mul])" + "708([Symbol name=convert_element_type])" -> "710([Symbol name=mul])" + "709([Symbol name=convert_element_type])" -> "710([Symbol name=mul])" + "703([Symbol name=neg])" + "702([Symbol name=convert_element_type])" -> "703([Symbol name=neg])" + "714([Symbol name=mul])" + "712([Symbol name=convert_element_type])" -> "714([Symbol name=mul])" + "713([Symbol name=convert_element_type])" -> "714([Symbol name=mul])" + "719([Symbol name=add])" + "717([Symbol name=convert_element_type])" -> "719([Symbol name=add])" + "718([Symbol name=convert_element_type])" -> "719([Symbol name=add])" + "634([Symbol name=mul])" + "632([Symbol name=convert_element_type])" -> "634([Symbol name=mul])" + "633([Symbol name=convert_element_type])" -> "634([Symbol name=mul])" + "698([Symbol name=mul])" + "696([Symbol name=convert_element_type])" -> "698([Symbol name=mul])" + "697([Symbol name=convert_element_type])" -> "698([Symbol name=mul])" + "738([Symbol name=transpose])" + "737([Symbol name=reshape])" -> "738([Symbol name=transpose])" + "783([Symbol name=add])" + "781([Symbol name=convert_element_type])" -> "783([Symbol name=add])" + "782([Symbol name=convert_element_type])" -> "783([Symbol name=add])" + "810([Symbol name=mul])" + "808([Symbol name=convert_element_type])" -> "810([Symbol name=mul])" + "809([Symbol name=convert_element_type])" -> "810([Symbol name=mul])" + "803([Symbol name=neg])" + "802([Symbol name=convert_element_type])" -> "803([Symbol name=neg])" + "814([Symbol name=mul])" + "812([Symbol name=convert_element_type])" -> "814([Symbol name=mul])" + "813([Symbol name=convert_element_type])" -> "814([Symbol name=mul])" + "819([Symbol name=add])" + "817([Symbol name=convert_element_type])" -> "819([Symbol name=add])" + "818([Symbol name=convert_element_type])" -> "819([Symbol name=add])" + "734([Symbol name=mul])" + "732([Symbol name=convert_element_type])" -> "734([Symbol name=mul])" + "733([Symbol name=convert_element_type])" -> "734([Symbol name=mul])" + "798([Symbol name=mul])" + "796([Symbol name=convert_element_type])" -> "798([Symbol name=mul])" + "797([Symbol name=convert_element_type])" -> "798([Symbol name=mul])" + "838([Symbol name=transpose])" + "837([Symbol name=reshape])" -> "838([Symbol name=transpose])" + "883([Symbol name=add])" + "881([Symbol name=convert_element_type])" -> "883([Symbol name=add])" + "882([Symbol name=convert_element_type])" -> "883([Symbol name=add])" + "910([Symbol name=mul])" + "908([Symbol name=convert_element_type])" -> "910([Symbol name=mul])" + "909([Symbol name=convert_element_type])" -> "910([Symbol name=mul])" + "903([Symbol name=neg])" + "902([Symbol name=convert_element_type])" -> "903([Symbol name=neg])" + "914([Symbol name=mul])" + "912([Symbol name=convert_element_type])" -> "914([Symbol name=mul])" + "913([Symbol name=convert_element_type])" -> "914([Symbol name=mul])" + "919([Symbol name=add])" + "917([Symbol name=convert_element_type])" -> "919([Symbol name=add])" + "918([Symbol name=convert_element_type])" -> "919([Symbol name=add])" + "834([Symbol name=mul])" + "832([Symbol name=convert_element_type])" -> "834([Symbol name=mul])" + "833([Symbol name=convert_element_type])" -> "834([Symbol name=mul])" + "898([Symbol name=mul])" + "896([Symbol name=convert_element_type])" -> "898([Symbol name=mul])" + "897([Symbol name=convert_element_type])" -> "898([Symbol name=mul])" + "938([Symbol name=transpose])" + "937([Symbol name=reshape])" -> "938([Symbol name=transpose])" + "983([Symbol name=add])" + "981([Symbol name=convert_element_type])" -> "983([Symbol name=add])" + "982([Symbol name=convert_element_type])" -> "983([Symbol name=add])" + "1010([Symbol name=mul])" + "1008([Symbol name=convert_element_type])" -> "1010([Symbol name=mul])" + "1009([Symbol name=convert_element_type])" -> "1010([Symbol name=mul])" + "1003([Symbol name=neg])" + "1002([Symbol name=convert_element_type])" -> "1003([Symbol name=neg])" + "1014([Symbol name=mul])" + "1012([Symbol name=convert_element_type])" -> "1014([Symbol name=mul])" + "1013([Symbol name=convert_element_type])" -> "1014([Symbol name=mul])" + "1019([Symbol name=add])" + "1017([Symbol name=convert_element_type])" -> "1019([Symbol name=add])" + "1018([Symbol name=convert_element_type])" -> "1019([Symbol name=add])" + "934([Symbol name=mul])" + "932([Symbol name=convert_element_type])" -> "934([Symbol name=mul])" + "933([Symbol name=convert_element_type])" -> "934([Symbol name=mul])" + "998([Symbol name=mul])" + "996([Symbol name=convert_element_type])" -> "998([Symbol name=mul])" + "997([Symbol name=convert_element_type])" -> "998([Symbol name=mul])" + "1038([Symbol name=transpose])" + "1037([Symbol name=reshape])" -> "1038([Symbol name=transpose])" + "1083([Symbol name=add])" + "1081([Symbol name=convert_element_type])" -> "1083([Symbol name=add])" + "1082([Symbol name=convert_element_type])" -> "1083([Symbol name=add])" + "1110([Symbol name=mul])" + "1108([Symbol name=convert_element_type])" -> "1110([Symbol name=mul])" + "1109([Symbol name=convert_element_type])" -> "1110([Symbol name=mul])" + "1103([Symbol name=neg])" + "1102([Symbol name=convert_element_type])" -> "1103([Symbol name=neg])" + "1114([Symbol name=mul])" + "1112([Symbol name=convert_element_type])" -> "1114([Symbol name=mul])" + "1113([Symbol name=convert_element_type])" -> "1114([Symbol name=mul])" + "1119([Symbol name=add])" + "1117([Symbol name=convert_element_type])" -> "1119([Symbol name=add])" + "1118([Symbol name=convert_element_type])" -> "1119([Symbol name=add])" + "1034([Symbol name=mul])" + "1032([Symbol name=convert_element_type])" -> "1034([Symbol name=mul])" + "1033([Symbol name=convert_element_type])" -> "1034([Symbol name=mul])" + "1098([Symbol name=mul])" + "1096([Symbol name=convert_element_type])" -> "1098([Symbol name=mul])" + "1097([Symbol name=convert_element_type])" -> "1098([Symbol name=mul])" + "1138([Symbol name=transpose])" + "1137([Symbol name=reshape])" -> "1138([Symbol name=transpose])" + "1183([Symbol name=add])" + "1181([Symbol name=convert_element_type])" -> "1183([Symbol name=add])" + "1182([Symbol name=convert_element_type])" -> "1183([Symbol name=add])" + "1210([Symbol name=mul])" + "1208([Symbol name=convert_element_type])" -> "1210([Symbol name=mul])" + "1209([Symbol name=convert_element_type])" -> "1210([Symbol name=mul])" + "1203([Symbol name=neg])" + "1202([Symbol name=convert_element_type])" -> "1203([Symbol name=neg])" + "1214([Symbol name=mul])" + "1212([Symbol name=convert_element_type])" -> "1214([Symbol name=mul])" + "1213([Symbol name=convert_element_type])" -> "1214([Symbol name=mul])" + "1219([Symbol name=add])" + "1217([Symbol name=convert_element_type])" -> "1219([Symbol name=add])" + "1218([Symbol name=convert_element_type])" -> "1219([Symbol name=add])" + "1134([Symbol name=mul])" + "1132([Symbol name=convert_element_type])" -> "1134([Symbol name=mul])" + "1133([Symbol name=convert_element_type])" -> "1134([Symbol name=mul])" + "1198([Symbol name=mul])" + "1196([Symbol name=convert_element_type])" -> "1198([Symbol name=mul])" + "1197([Symbol name=convert_element_type])" -> "1198([Symbol name=mul])" + "1238([Symbol name=transpose])" + "1237([Symbol name=reshape])" -> "1238([Symbol name=transpose])" + "1283([Symbol name=add])" + "1281([Symbol name=convert_element_type])" -> "1283([Symbol name=add])" + "1282([Symbol name=convert_element_type])" -> "1283([Symbol name=add])" + "1310([Symbol name=mul])" + "1308([Symbol name=convert_element_type])" -> "1310([Symbol name=mul])" + "1309([Symbol name=convert_element_type])" -> "1310([Symbol name=mul])" + "1303([Symbol name=neg])" + "1302([Symbol name=convert_element_type])" -> "1303([Symbol name=neg])" + "1314([Symbol name=mul])" + "1312([Symbol name=convert_element_type])" -> "1314([Symbol name=mul])" + "1313([Symbol name=convert_element_type])" -> "1314([Symbol name=mul])" + "1319([Symbol name=add])" + "1317([Symbol name=convert_element_type])" -> "1319([Symbol name=add])" + "1318([Symbol name=convert_element_type])" -> "1319([Symbol name=add])" + "1234([Symbol name=mul])" + "1232([Symbol name=convert_element_type])" -> "1234([Symbol name=mul])" + "1233([Symbol name=convert_element_type])" -> "1234([Symbol name=mul])" + "1298([Symbol name=mul])" + "1296([Symbol name=convert_element_type])" -> "1298([Symbol name=mul])" + "1297([Symbol name=convert_element_type])" -> "1298([Symbol name=mul])" + "1338([Symbol name=transpose])" + "1337([Symbol name=reshape])" -> "1338([Symbol name=transpose])" + "1383([Symbol name=add])" + "1381([Symbol name=convert_element_type])" -> "1383([Symbol name=add])" + "1382([Symbol name=convert_element_type])" -> "1383([Symbol name=add])" + "1410([Symbol name=mul])" + "1408([Symbol name=convert_element_type])" -> "1410([Symbol name=mul])" + "1409([Symbol name=convert_element_type])" -> "1410([Symbol name=mul])" + "1403([Symbol name=neg])" + "1402([Symbol name=convert_element_type])" -> "1403([Symbol name=neg])" + "1414([Symbol name=mul])" + "1412([Symbol name=convert_element_type])" -> "1414([Symbol name=mul])" + "1413([Symbol name=convert_element_type])" -> "1414([Symbol name=mul])" + "1419([Symbol name=add])" + "1417([Symbol name=convert_element_type])" -> "1419([Symbol name=add])" + "1418([Symbol name=convert_element_type])" -> "1419([Symbol name=add])" + "1334([Symbol name=mul])" + "1332([Symbol name=convert_element_type])" -> "1334([Symbol name=mul])" + "1333([Symbol name=convert_element_type])" -> "1334([Symbol name=mul])" + "1398([Symbol name=mul])" + "1396([Symbol name=convert_element_type])" -> "1398([Symbol name=mul])" + "1397([Symbol name=convert_element_type])" -> "1398([Symbol name=mul])" + "1438([Symbol name=transpose])" + "1437([Symbol name=reshape])" -> "1438([Symbol name=transpose])" + "1483([Symbol name=add])" + "1481([Symbol name=convert_element_type])" -> "1483([Symbol name=add])" + "1482([Symbol name=convert_element_type])" -> "1483([Symbol name=add])" + "1510([Symbol name=mul])" + "1508([Symbol name=convert_element_type])" -> "1510([Symbol name=mul])" + "1509([Symbol name=convert_element_type])" -> "1510([Symbol name=mul])" + "1503([Symbol name=neg])" + "1502([Symbol name=convert_element_type])" -> "1503([Symbol name=neg])" + "1514([Symbol name=mul])" + "1512([Symbol name=convert_element_type])" -> "1514([Symbol name=mul])" + "1513([Symbol name=convert_element_type])" -> "1514([Symbol name=mul])" + "1519([Symbol name=add])" + "1517([Symbol name=convert_element_type])" -> "1519([Symbol name=add])" + "1518([Symbol name=convert_element_type])" -> "1519([Symbol name=add])" + "1434([Symbol name=mul])" + "1432([Symbol name=convert_element_type])" -> "1434([Symbol name=mul])" + "1433([Symbol name=convert_element_type])" -> "1434([Symbol name=mul])" + "1498([Symbol name=mul])" + "1496([Symbol name=convert_element_type])" -> "1498([Symbol name=mul])" + "1497([Symbol name=convert_element_type])" -> "1498([Symbol name=mul])" + "1538([Symbol name=transpose])" + "1537([Symbol name=reshape])" -> "1538([Symbol name=transpose])" + "1583([Symbol name=add])" + "1581([Symbol name=convert_element_type])" -> "1583([Symbol name=add])" + "1582([Symbol name=convert_element_type])" -> "1583([Symbol name=add])" + "1610([Symbol name=mul])" + "1608([Symbol name=convert_element_type])" -> "1610([Symbol name=mul])" + "1609([Symbol name=convert_element_type])" -> "1610([Symbol name=mul])" + "1603([Symbol name=neg])" + "1602([Symbol name=convert_element_type])" -> "1603([Symbol name=neg])" + "1614([Symbol name=mul])" + "1612([Symbol name=convert_element_type])" -> "1614([Symbol name=mul])" + "1613([Symbol name=convert_element_type])" -> "1614([Symbol name=mul])" + "1619([Symbol name=add])" + "1617([Symbol name=convert_element_type])" -> "1619([Symbol name=add])" + "1618([Symbol name=convert_element_type])" -> "1619([Symbol name=add])" + "1534([Symbol name=mul])" + "1532([Symbol name=convert_element_type])" -> "1534([Symbol name=mul])" + "1533([Symbol name=convert_element_type])" -> "1534([Symbol name=mul])" + "1598([Symbol name=mul])" + "1596([Symbol name=convert_element_type])" -> "1598([Symbol name=mul])" + "1597([Symbol name=convert_element_type])" -> "1598([Symbol name=mul])" + "1638([Symbol name=transpose])" + "1637([Symbol name=reshape])" -> "1638([Symbol name=transpose])" + "1683([Symbol name=add])" + "1681([Symbol name=convert_element_type])" -> "1683([Symbol name=add])" + "1682([Symbol name=convert_element_type])" -> "1683([Symbol name=add])" + "1710([Symbol name=mul])" + "1708([Symbol name=convert_element_type])" -> "1710([Symbol name=mul])" + "1709([Symbol name=convert_element_type])" -> "1710([Symbol name=mul])" + "1703([Symbol name=neg])" + "1702([Symbol name=convert_element_type])" -> "1703([Symbol name=neg])" + "1714([Symbol name=mul])" + "1712([Symbol name=convert_element_type])" -> "1714([Symbol name=mul])" + "1713([Symbol name=convert_element_type])" -> "1714([Symbol name=mul])" + "1719([Symbol name=add])" + "1717([Symbol name=convert_element_type])" -> "1719([Symbol name=add])" + "1718([Symbol name=convert_element_type])" -> "1719([Symbol name=add])" + "1634([Symbol name=mul])" + "1632([Symbol name=convert_element_type])" -> "1634([Symbol name=mul])" + "1633([Symbol name=convert_element_type])" -> "1634([Symbol name=mul])" + "1698([Symbol name=mul])" + "1696([Symbol name=convert_element_type])" -> "1698([Symbol name=mul])" + "1697([Symbol name=convert_element_type])" -> "1698([Symbol name=mul])" + "1734([Symbol name=mul])" + "1732([Symbol name=convert_element_type])" -> "1734([Symbol name=mul])" + "1733([Symbol name=convert_element_type])" -> "1734([Symbol name=mul])" + "130([Symbol name=convert_element_type])" + "129([Symbol name=mul])" -> "130([Symbol name=convert_element_type])" + "123([Symbol name=sum])" + "122([Symbol name=mul])" -> "123([Symbol name=sum])" + "184([Symbol name=convert_element_type])" + "183([Symbol name=add])" -> "184([Symbol name=convert_element_type])" + "1671([Symbol name=add])" + "1667([Symbol name=mul])" -> "1671([Symbol name=add])" + "1670([Symbol name=mul])" -> "1671([Symbol name=add])" + "271([Symbol name=add])" + "267([Symbol name=mul])" -> "271([Symbol name=add])" + "270([Symbol name=mul])" -> "271([Symbol name=add])" + "656([Symbol name=add])" + "652([Symbol name=mul])" -> "656([Symbol name=add])" + "655([Symbol name=mul])" -> "656([Symbol name=add])" + "1171([Symbol name=add])" + "1170([Symbol name=mul])" -> "1171([Symbol name=add])" + "1167([Symbol name=mul])" -> "1171([Symbol name=add])" + "1556([Symbol name=add])" + "1552([Symbol name=mul])" -> "1556([Symbol name=add])" + "1555([Symbol name=mul])" -> "1556([Symbol name=add])" + "156([Symbol name=add])" + "152([Symbol name=mul])" -> "156([Symbol name=add])" + "155([Symbol name=mul])" -> "156([Symbol name=add])" + "671([Symbol name=add])" + "667([Symbol name=mul])" -> "671([Symbol name=add])" + "670([Symbol name=mul])" -> "671([Symbol name=add])" + "1056([Symbol name=add])" + "1052([Symbol name=mul])" -> "1056([Symbol name=add])" + "1055([Symbol name=mul])" -> "1056([Symbol name=add])" + "1571([Symbol name=add])" + "1570([Symbol name=mul])" -> "1571([Symbol name=add])" + "1567([Symbol name=mul])" -> "1571([Symbol name=add])" + "171([Symbol name=add])" + "170([Symbol name=mul])" -> "171([Symbol name=add])" + "167([Symbol name=mul])" -> "171([Symbol name=add])" + "556([Symbol name=add])" + "552([Symbol name=mul])" -> "556([Symbol name=add])" + "555([Symbol name=mul])" -> "556([Symbol name=add])" + "1071([Symbol name=add])" + "1067([Symbol name=mul])" -> "1071([Symbol name=add])" + "1070([Symbol name=mul])" -> "1071([Symbol name=add])" + "1456([Symbol name=add])" + "1452([Symbol name=mul])" -> "1456([Symbol name=add])" + "1455([Symbol name=mul])" -> "1456([Symbol name=add])" + "571([Symbol name=add])" + "570([Symbol name=mul])" -> "571([Symbol name=add])" + "567([Symbol name=mul])" -> "571([Symbol name=add])" + "956([Symbol name=add])" + "952([Symbol name=mul])" -> "956([Symbol name=add])" + "955([Symbol name=mul])" -> "956([Symbol name=add])" + "1471([Symbol name=add])" + "1467([Symbol name=mul])" -> "1471([Symbol name=add])" + "1470([Symbol name=mul])" -> "1471([Symbol name=add])" + "456([Symbol name=add])" + "452([Symbol name=mul])" -> "456([Symbol name=add])" + "455([Symbol name=mul])" -> "456([Symbol name=add])" + "971([Symbol name=add])" + "970([Symbol name=mul])" -> "971([Symbol name=add])" + "967([Symbol name=mul])" -> "971([Symbol name=add])" + "1356([Symbol name=add])" + "1352([Symbol name=mul])" -> "1356([Symbol name=add])" + "1355([Symbol name=mul])" -> "1356([Symbol name=add])" + "471([Symbol name=add])" + "467([Symbol name=mul])" -> "471([Symbol name=add])" + "470([Symbol name=mul])" -> "471([Symbol name=add])" + "856([Symbol name=add])" + "852([Symbol name=mul])" -> "856([Symbol name=add])" + "855([Symbol name=mul])" -> "856([Symbol name=add])" + "1371([Symbol name=add])" + "1370([Symbol name=mul])" -> "1371([Symbol name=add])" + "1367([Symbol name=mul])" -> "1371([Symbol name=add])" + "356([Symbol name=add])" + "352([Symbol name=mul])" -> "356([Symbol name=add])" + "355([Symbol name=mul])" -> "356([Symbol name=add])" + "871([Symbol name=add])" + "867([Symbol name=mul])" -> "871([Symbol name=add])" + "870([Symbol name=mul])" -> "871([Symbol name=add])" + "1256([Symbol name=add])" + "1252([Symbol name=mul])" -> "1256([Symbol name=add])" + "1255([Symbol name=mul])" -> "1256([Symbol name=add])" + "371([Symbol name=add])" + "370([Symbol name=mul])" -> "371([Symbol name=add])" + "367([Symbol name=mul])" -> "371([Symbol name=add])" + "756([Symbol name=add])" + "752([Symbol name=mul])" -> "756([Symbol name=add])" + "755([Symbol name=mul])" -> "756([Symbol name=add])" + "1271([Symbol name=add])" + "1267([Symbol name=mul])" -> "1271([Symbol name=add])" + "1270([Symbol name=mul])" -> "1271([Symbol name=add])" + "1656([Symbol name=add])" + "1652([Symbol name=mul])" -> "1656([Symbol name=add])" + "1655([Symbol name=mul])" -> "1656([Symbol name=add])" + "256([Symbol name=add])" + "252([Symbol name=mul])" -> "256([Symbol name=add])" + "255([Symbol name=mul])" -> "256([Symbol name=add])" + "771([Symbol name=add])" + "770([Symbol name=mul])" -> "771([Symbol name=add])" + "767([Symbol name=mul])" -> "771([Symbol name=add])" + "1156([Symbol name=add])" + "1152([Symbol name=mul])" -> "1156([Symbol name=add])" + "1155([Symbol name=mul])" -> "1156([Symbol name=add])" + "139([Symbol name=split])" + "138([Symbol name=transpose])" -> "139([Symbol name=split])" + "211([Symbol name=convert_element_type])" + "210([Symbol name=mul])" -> "211([Symbol name=convert_element_type])" + "204([Symbol name=exp])" + "203([Symbol name=neg])" -> "204([Symbol name=exp])" + "215([Symbol name=convert_element_type])" + "214([Symbol name=mul])" -> "215([Symbol name=convert_element_type])" + "220([Symbol name=convert_element_type])" + "219([Symbol name=add])" -> "220([Symbol name=convert_element_type])" + "135([Symbol name=convert_element_type])" + "134([Symbol name=mul])" -> "135([Symbol name=convert_element_type])" + "199([Symbol name=convert_element_type])" + "198([Symbol name=mul])" -> "199([Symbol name=convert_element_type])" + "239([Symbol name=split])" + "238([Symbol name=transpose])" -> "239([Symbol name=split])" + "284([Symbol name=convert_element_type])" + "283([Symbol name=add])" -> "284([Symbol name=convert_element_type])" + "311([Symbol name=convert_element_type])" + "310([Symbol name=mul])" -> "311([Symbol name=convert_element_type])" + "304([Symbol name=exp])" + "303([Symbol name=neg])" -> "304([Symbol name=exp])" + "315([Symbol name=convert_element_type])" + "314([Symbol name=mul])" -> "315([Symbol name=convert_element_type])" + "320([Symbol name=convert_element_type])" + "319([Symbol name=add])" -> "320([Symbol name=convert_element_type])" + "235([Symbol name=convert_element_type])" + "234([Symbol name=mul])" -> "235([Symbol name=convert_element_type])" + "299([Symbol name=convert_element_type])" + "298([Symbol name=mul])" -> "299([Symbol name=convert_element_type])" + "339([Symbol name=split])" + "338([Symbol name=transpose])" -> "339([Symbol name=split])" + "384([Symbol name=convert_element_type])" + "383([Symbol name=add])" -> "384([Symbol name=convert_element_type])" + "411([Symbol name=convert_element_type])" + "410([Symbol name=mul])" -> "411([Symbol name=convert_element_type])" + "404([Symbol name=exp])" + "403([Symbol name=neg])" -> "404([Symbol name=exp])" + "415([Symbol name=convert_element_type])" + "414([Symbol name=mul])" -> "415([Symbol name=convert_element_type])" + "420([Symbol name=convert_element_type])" + "419([Symbol name=add])" -> "420([Symbol name=convert_element_type])" + "335([Symbol name=convert_element_type])" + "334([Symbol name=mul])" -> "335([Symbol name=convert_element_type])" + "399([Symbol name=convert_element_type])" + "398([Symbol name=mul])" -> "399([Symbol name=convert_element_type])" + "439([Symbol name=split])" + "438([Symbol name=transpose])" -> "439([Symbol name=split])" + "484([Symbol name=convert_element_type])" + "483([Symbol name=add])" -> "484([Symbol name=convert_element_type])" + "511([Symbol name=convert_element_type])" + "510([Symbol name=mul])" -> "511([Symbol name=convert_element_type])" + "504([Symbol name=exp])" + "503([Symbol name=neg])" -> "504([Symbol name=exp])" + "515([Symbol name=convert_element_type])" + "514([Symbol name=mul])" -> "515([Symbol name=convert_element_type])" + "520([Symbol name=convert_element_type])" + "519([Symbol name=add])" -> "520([Symbol name=convert_element_type])" + "435([Symbol name=convert_element_type])" + "434([Symbol name=mul])" -> "435([Symbol name=convert_element_type])" + "499([Symbol name=convert_element_type])" + "498([Symbol name=mul])" -> "499([Symbol name=convert_element_type])" + "539([Symbol name=split])" + "538([Symbol name=transpose])" -> "539([Symbol name=split])" + "584([Symbol name=convert_element_type])" + "583([Symbol name=add])" -> "584([Symbol name=convert_element_type])" + "611([Symbol name=convert_element_type])" + "610([Symbol name=mul])" -> "611([Symbol name=convert_element_type])" + "604([Symbol name=exp])" + "603([Symbol name=neg])" -> "604([Symbol name=exp])" + "615([Symbol name=convert_element_type])" + "614([Symbol name=mul])" -> "615([Symbol name=convert_element_type])" + "620([Symbol name=convert_element_type])" + "619([Symbol name=add])" -> "620([Symbol name=convert_element_type])" + "535([Symbol name=convert_element_type])" + "534([Symbol name=mul])" -> "535([Symbol name=convert_element_type])" + "599([Symbol name=convert_element_type])" + "598([Symbol name=mul])" -> "599([Symbol name=convert_element_type])" + "639([Symbol name=split])" + "638([Symbol name=transpose])" -> "639([Symbol name=split])" + "684([Symbol name=convert_element_type])" + "683([Symbol name=add])" -> "684([Symbol name=convert_element_type])" + "711([Symbol name=convert_element_type])" + "710([Symbol name=mul])" -> "711([Symbol name=convert_element_type])" + "704([Symbol name=exp])" + "703([Symbol name=neg])" -> "704([Symbol name=exp])" + "715([Symbol name=convert_element_type])" + "714([Symbol name=mul])" -> "715([Symbol name=convert_element_type])" + "720([Symbol name=convert_element_type])" + "719([Symbol name=add])" -> "720([Symbol name=convert_element_type])" + "635([Symbol name=convert_element_type])" + "634([Symbol name=mul])" -> "635([Symbol name=convert_element_type])" + "699([Symbol name=convert_element_type])" + "698([Symbol name=mul])" -> "699([Symbol name=convert_element_type])" + "739([Symbol name=split])" + "738([Symbol name=transpose])" -> "739([Symbol name=split])" + "784([Symbol name=convert_element_type])" + "783([Symbol name=add])" -> "784([Symbol name=convert_element_type])" + "811([Symbol name=convert_element_type])" + "810([Symbol name=mul])" -> "811([Symbol name=convert_element_type])" + "804([Symbol name=exp])" + "803([Symbol name=neg])" -> "804([Symbol name=exp])" + "815([Symbol name=convert_element_type])" + "814([Symbol name=mul])" -> "815([Symbol name=convert_element_type])" + "820([Symbol name=convert_element_type])" + "819([Symbol name=add])" -> "820([Symbol name=convert_element_type])" + "735([Symbol name=convert_element_type])" + "734([Symbol name=mul])" -> "735([Symbol name=convert_element_type])" + "799([Symbol name=convert_element_type])" + "798([Symbol name=mul])" -> "799([Symbol name=convert_element_type])" + "839([Symbol name=split])" + "838([Symbol name=transpose])" -> "839([Symbol name=split])" + "884([Symbol name=convert_element_type])" + "883([Symbol name=add])" -> "884([Symbol name=convert_element_type])" + "911([Symbol name=convert_element_type])" + "910([Symbol name=mul])" -> "911([Symbol name=convert_element_type])" + "904([Symbol name=exp])" + "903([Symbol name=neg])" -> "904([Symbol name=exp])" + "915([Symbol name=convert_element_type])" + "914([Symbol name=mul])" -> "915([Symbol name=convert_element_type])" + "920([Symbol name=convert_element_type])" + "919([Symbol name=add])" -> "920([Symbol name=convert_element_type])" + "835([Symbol name=convert_element_type])" + "834([Symbol name=mul])" -> "835([Symbol name=convert_element_type])" + "899([Symbol name=convert_element_type])" + "898([Symbol name=mul])" -> "899([Symbol name=convert_element_type])" + "939([Symbol name=split])" + "938([Symbol name=transpose])" -> "939([Symbol name=split])" + "984([Symbol name=convert_element_type])" + "983([Symbol name=add])" -> "984([Symbol name=convert_element_type])" + "1011([Symbol name=convert_element_type])" + "1010([Symbol name=mul])" -> "1011([Symbol name=convert_element_type])" + "1004([Symbol name=exp])" + "1003([Symbol name=neg])" -> "1004([Symbol name=exp])" + "1015([Symbol name=convert_element_type])" + "1014([Symbol name=mul])" -> "1015([Symbol name=convert_element_type])" + "1020([Symbol name=convert_element_type])" + "1019([Symbol name=add])" -> "1020([Symbol name=convert_element_type])" + "935([Symbol name=convert_element_type])" + "934([Symbol name=mul])" -> "935([Symbol name=convert_element_type])" + "999([Symbol name=convert_element_type])" + "998([Symbol name=mul])" -> "999([Symbol name=convert_element_type])" + "1039([Symbol name=split])" + "1038([Symbol name=transpose])" -> "1039([Symbol name=split])" + "1084([Symbol name=convert_element_type])" + "1083([Symbol name=add])" -> "1084([Symbol name=convert_element_type])" + "1111([Symbol name=convert_element_type])" + "1110([Symbol name=mul])" -> "1111([Symbol name=convert_element_type])" + "1104([Symbol name=exp])" + "1103([Symbol name=neg])" -> "1104([Symbol name=exp])" + "1115([Symbol name=convert_element_type])" + "1114([Symbol name=mul])" -> "1115([Symbol name=convert_element_type])" + "1120([Symbol name=convert_element_type])" + "1119([Symbol name=add])" -> "1120([Symbol name=convert_element_type])" + "1035([Symbol name=convert_element_type])" + "1034([Symbol name=mul])" -> "1035([Symbol name=convert_element_type])" + "1099([Symbol name=convert_element_type])" + "1098([Symbol name=mul])" -> "1099([Symbol name=convert_element_type])" + "1139([Symbol name=split])" + "1138([Symbol name=transpose])" -> "1139([Symbol name=split])" + "1184([Symbol name=convert_element_type])" + "1183([Symbol name=add])" -> "1184([Symbol name=convert_element_type])" + "1211([Symbol name=convert_element_type])" + "1210([Symbol name=mul])" -> "1211([Symbol name=convert_element_type])" + "1204([Symbol name=exp])" + "1203([Symbol name=neg])" -> "1204([Symbol name=exp])" + "1215([Symbol name=convert_element_type])" + "1214([Symbol name=mul])" -> "1215([Symbol name=convert_element_type])" + "1220([Symbol name=convert_element_type])" + "1219([Symbol name=add])" -> "1220([Symbol name=convert_element_type])" + "1135([Symbol name=convert_element_type])" + "1134([Symbol name=mul])" -> "1135([Symbol name=convert_element_type])" + "1199([Symbol name=convert_element_type])" + "1198([Symbol name=mul])" -> "1199([Symbol name=convert_element_type])" + "1239([Symbol name=split])" + "1238([Symbol name=transpose])" -> "1239([Symbol name=split])" + "1284([Symbol name=convert_element_type])" + "1283([Symbol name=add])" -> "1284([Symbol name=convert_element_type])" + "1311([Symbol name=convert_element_type])" + "1310([Symbol name=mul])" -> "1311([Symbol name=convert_element_type])" + "1304([Symbol name=exp])" + "1303([Symbol name=neg])" -> "1304([Symbol name=exp])" + "1315([Symbol name=convert_element_type])" + "1314([Symbol name=mul])" -> "1315([Symbol name=convert_element_type])" + "1320([Symbol name=convert_element_type])" + "1319([Symbol name=add])" -> "1320([Symbol name=convert_element_type])" + "1235([Symbol name=convert_element_type])" + "1234([Symbol name=mul])" -> "1235([Symbol name=convert_element_type])" + "1299([Symbol name=convert_element_type])" + "1298([Symbol name=mul])" -> "1299([Symbol name=convert_element_type])" + "1339([Symbol name=split])" + "1338([Symbol name=transpose])" -> "1339([Symbol name=split])" + "1384([Symbol name=convert_element_type])" + "1383([Symbol name=add])" -> "1384([Symbol name=convert_element_type])" + "1411([Symbol name=convert_element_type])" + "1410([Symbol name=mul])" -> "1411([Symbol name=convert_element_type])" + "1404([Symbol name=exp])" + "1403([Symbol name=neg])" -> "1404([Symbol name=exp])" + "1415([Symbol name=convert_element_type])" + "1414([Symbol name=mul])" -> "1415([Symbol name=convert_element_type])" + "1420([Symbol name=convert_element_type])" + "1419([Symbol name=add])" -> "1420([Symbol name=convert_element_type])" + "1335([Symbol name=convert_element_type])" + "1334([Symbol name=mul])" -> "1335([Symbol name=convert_element_type])" + "1399([Symbol name=convert_element_type])" + "1398([Symbol name=mul])" -> "1399([Symbol name=convert_element_type])" + "1439([Symbol name=split])" + "1438([Symbol name=transpose])" -> "1439([Symbol name=split])" + "1484([Symbol name=convert_element_type])" + "1483([Symbol name=add])" -> "1484([Symbol name=convert_element_type])" + "1511([Symbol name=convert_element_type])" + "1510([Symbol name=mul])" -> "1511([Symbol name=convert_element_type])" + "1504([Symbol name=exp])" + "1503([Symbol name=neg])" -> "1504([Symbol name=exp])" + "1515([Symbol name=convert_element_type])" + "1514([Symbol name=mul])" -> "1515([Symbol name=convert_element_type])" + "1520([Symbol name=convert_element_type])" + "1519([Symbol name=add])" -> "1520([Symbol name=convert_element_type])" + "1435([Symbol name=convert_element_type])" + "1434([Symbol name=mul])" -> "1435([Symbol name=convert_element_type])" + "1499([Symbol name=convert_element_type])" + "1498([Symbol name=mul])" -> "1499([Symbol name=convert_element_type])" + "1539([Symbol name=split])" + "1538([Symbol name=transpose])" -> "1539([Symbol name=split])" + "1584([Symbol name=convert_element_type])" + "1583([Symbol name=add])" -> "1584([Symbol name=convert_element_type])" + "1611([Symbol name=convert_element_type])" + "1610([Symbol name=mul])" -> "1611([Symbol name=convert_element_type])" + "1604([Symbol name=exp])" + "1603([Symbol name=neg])" -> "1604([Symbol name=exp])" + "1615([Symbol name=convert_element_type])" + "1614([Symbol name=mul])" -> "1615([Symbol name=convert_element_type])" + "1620([Symbol name=convert_element_type])" + "1619([Symbol name=add])" -> "1620([Symbol name=convert_element_type])" + "1535([Symbol name=convert_element_type])" + "1534([Symbol name=mul])" -> "1535([Symbol name=convert_element_type])" + "1599([Symbol name=convert_element_type])" + "1598([Symbol name=mul])" -> "1599([Symbol name=convert_element_type])" + "1639([Symbol name=split])" + "1638([Symbol name=transpose])" -> "1639([Symbol name=split])" + "1684([Symbol name=convert_element_type])" + "1683([Symbol name=add])" -> "1684([Symbol name=convert_element_type])" + "1711([Symbol name=convert_element_type])" + "1710([Symbol name=mul])" -> "1711([Symbol name=convert_element_type])" + "1704([Symbol name=exp])" + "1703([Symbol name=neg])" -> "1704([Symbol name=exp])" + "1715([Symbol name=convert_element_type])" + "1714([Symbol name=mul])" -> "1715([Symbol name=convert_element_type])" + "1720([Symbol name=convert_element_type])" + "1719([Symbol name=add])" -> "1720([Symbol name=convert_element_type])" + "1635([Symbol name=convert_element_type])" + "1634([Symbol name=mul])" -> "1635([Symbol name=convert_element_type])" + "1699([Symbol name=convert_element_type])" + "1698([Symbol name=mul])" -> "1699([Symbol name=convert_element_type])" + "1735([Symbol name=convert_element_type])" + "1734([Symbol name=mul])" -> "1735([Symbol name=convert_element_type])" + "132([Symbol name=convert_element_type])" + "130([Symbol name=convert_element_type])" -> "132([Symbol name=convert_element_type])" + "124([Symbol name=broadcast_in_dim])" + "123([Symbol name=sum])" -> "124([Symbol name=broadcast_in_dim])" + "185([Symbol name=convert_element_type])" + "184([Symbol name=convert_element_type])" -> "185([Symbol name=convert_element_type])" + "218([Symbol name=convert_element_type])" + "184([Symbol name=convert_element_type])" -> "218([Symbol name=convert_element_type])" + "1672([Symbol name=convert_element_type])" + "1671([Symbol name=add])" -> "1672([Symbol name=convert_element_type])" + "272([Symbol name=convert_element_type])" + "271([Symbol name=add])" -> "272([Symbol name=convert_element_type])" + "657([Symbol name=convert_element_type])" + "656([Symbol name=add])" -> "657([Symbol name=convert_element_type])" + "1172([Symbol name=convert_element_type])" + "1171([Symbol name=add])" -> "1172([Symbol name=convert_element_type])" + "1557([Symbol name=convert_element_type])" + "1556([Symbol name=add])" -> "1557([Symbol name=convert_element_type])" + "157([Symbol name=convert_element_type])" + "156([Symbol name=add])" -> "157([Symbol name=convert_element_type])" + "672([Symbol name=convert_element_type])" + "671([Symbol name=add])" -> "672([Symbol name=convert_element_type])" + "1057([Symbol name=convert_element_type])" + "1056([Symbol name=add])" -> "1057([Symbol name=convert_element_type])" + "1572([Symbol name=convert_element_type])" + "1571([Symbol name=add])" -> "1572([Symbol name=convert_element_type])" + "172([Symbol name=convert_element_type])" + "171([Symbol name=add])" -> "172([Symbol name=convert_element_type])" + "557([Symbol name=convert_element_type])" + "556([Symbol name=add])" -> "557([Symbol name=convert_element_type])" + "1072([Symbol name=convert_element_type])" + "1071([Symbol name=add])" -> "1072([Symbol name=convert_element_type])" + "1457([Symbol name=convert_element_type])" + "1456([Symbol name=add])" -> "1457([Symbol name=convert_element_type])" + "572([Symbol name=convert_element_type])" + "571([Symbol name=add])" -> "572([Symbol name=convert_element_type])" + "957([Symbol name=convert_element_type])" + "956([Symbol name=add])" -> "957([Symbol name=convert_element_type])" + "1472([Symbol name=convert_element_type])" + "1471([Symbol name=add])" -> "1472([Symbol name=convert_element_type])" + "457([Symbol name=convert_element_type])" + "456([Symbol name=add])" -> "457([Symbol name=convert_element_type])" + "972([Symbol name=convert_element_type])" + "971([Symbol name=add])" -> "972([Symbol name=convert_element_type])" + "1357([Symbol name=convert_element_type])" + "1356([Symbol name=add])" -> "1357([Symbol name=convert_element_type])" + "472([Symbol name=convert_element_type])" + "471([Symbol name=add])" -> "472([Symbol name=convert_element_type])" + "857([Symbol name=convert_element_type])" + "856([Symbol name=add])" -> "857([Symbol name=convert_element_type])" + "1372([Symbol name=convert_element_type])" + "1371([Symbol name=add])" -> "1372([Symbol name=convert_element_type])" + "357([Symbol name=convert_element_type])" + "356([Symbol name=add])" -> "357([Symbol name=convert_element_type])" + "872([Symbol name=convert_element_type])" + "871([Symbol name=add])" -> "872([Symbol name=convert_element_type])" + "1257([Symbol name=convert_element_type])" + "1256([Symbol name=add])" -> "1257([Symbol name=convert_element_type])" + "372([Symbol name=convert_element_type])" + "371([Symbol name=add])" -> "372([Symbol name=convert_element_type])" + "757([Symbol name=convert_element_type])" + "756([Symbol name=add])" -> "757([Symbol name=convert_element_type])" + "1272([Symbol name=convert_element_type])" + "1271([Symbol name=add])" -> "1272([Symbol name=convert_element_type])" + "1657([Symbol name=convert_element_type])" + "1656([Symbol name=add])" -> "1657([Symbol name=convert_element_type])" + "257([Symbol name=convert_element_type])" + "256([Symbol name=add])" -> "257([Symbol name=convert_element_type])" + "772([Symbol name=convert_element_type])" + "771([Symbol name=add])" -> "772([Symbol name=convert_element_type])" + "1157([Symbol name=convert_element_type])" + "1156([Symbol name=add])" -> "1157([Symbol name=convert_element_type])" + "140([Symbol name=reshape])" + "139([Symbol name=split])" -> "140([Symbol name=reshape])" + "141([Symbol name=reshape])" + "139([Symbol name=split])" -> "141([Symbol name=reshape])" + "142([Symbol name=reshape])" + "139([Symbol name=split])" -> "142([Symbol name=reshape])" + "212([Symbol name=convert_element_type])" + "211([Symbol name=convert_element_type])" -> "212([Symbol name=convert_element_type])" + "205([Symbol name=add])" + "204([Symbol name=exp])" -> "205([Symbol name=add])" + "282([Symbol name=convert_element_type])" + "220([Symbol name=convert_element_type])" -> "282([Symbol name=convert_element_type])" + "221([Symbol name=convert_element_type])" + "220([Symbol name=convert_element_type])" -> "221([Symbol name=convert_element_type])" + "240([Symbol name=reshape])" + "239([Symbol name=split])" -> "240([Symbol name=reshape])" + "241([Symbol name=reshape])" + "239([Symbol name=split])" -> "241([Symbol name=reshape])" + "242([Symbol name=reshape])" + "239([Symbol name=split])" -> "242([Symbol name=reshape])" + "285([Symbol name=convert_element_type])" + "284([Symbol name=convert_element_type])" -> "285([Symbol name=convert_element_type])" + "318([Symbol name=convert_element_type])" + "284([Symbol name=convert_element_type])" -> "318([Symbol name=convert_element_type])" + "312([Symbol name=convert_element_type])" + "311([Symbol name=convert_element_type])" -> "312([Symbol name=convert_element_type])" + "305([Symbol name=add])" + "304([Symbol name=exp])" -> "305([Symbol name=add])" + "321([Symbol name=convert_element_type])" + "320([Symbol name=convert_element_type])" -> "321([Symbol name=convert_element_type])" + "382([Symbol name=convert_element_type])" + "320([Symbol name=convert_element_type])" -> "382([Symbol name=convert_element_type])" + "340([Symbol name=reshape])" + "339([Symbol name=split])" -> "340([Symbol name=reshape])" + "341([Symbol name=reshape])" + "339([Symbol name=split])" -> "341([Symbol name=reshape])" + "342([Symbol name=reshape])" + "339([Symbol name=split])" -> "342([Symbol name=reshape])" + "385([Symbol name=convert_element_type])" + "384([Symbol name=convert_element_type])" -> "385([Symbol name=convert_element_type])" + "418([Symbol name=convert_element_type])" + "384([Symbol name=convert_element_type])" -> "418([Symbol name=convert_element_type])" + "412([Symbol name=convert_element_type])" + "411([Symbol name=convert_element_type])" -> "412([Symbol name=convert_element_type])" + "405([Symbol name=add])" + "404([Symbol name=exp])" -> "405([Symbol name=add])" + "482([Symbol name=convert_element_type])" + "420([Symbol name=convert_element_type])" -> "482([Symbol name=convert_element_type])" + "421([Symbol name=convert_element_type])" + "420([Symbol name=convert_element_type])" -> "421([Symbol name=convert_element_type])" + "440([Symbol name=reshape])" + "439([Symbol name=split])" -> "440([Symbol name=reshape])" + "441([Symbol name=reshape])" + "439([Symbol name=split])" -> "441([Symbol name=reshape])" + "442([Symbol name=reshape])" + "439([Symbol name=split])" -> "442([Symbol name=reshape])" + "485([Symbol name=convert_element_type])" + "484([Symbol name=convert_element_type])" -> "485([Symbol name=convert_element_type])" + "518([Symbol name=convert_element_type])" + "484([Symbol name=convert_element_type])" -> "518([Symbol name=convert_element_type])" + "512([Symbol name=convert_element_type])" + "511([Symbol name=convert_element_type])" -> "512([Symbol name=convert_element_type])" + "505([Symbol name=add])" + "504([Symbol name=exp])" -> "505([Symbol name=add])" + "521([Symbol name=convert_element_type])" + "520([Symbol name=convert_element_type])" -> "521([Symbol name=convert_element_type])" + "582([Symbol name=convert_element_type])" + "520([Symbol name=convert_element_type])" -> "582([Symbol name=convert_element_type])" + "540([Symbol name=reshape])" + "539([Symbol name=split])" -> "540([Symbol name=reshape])" + "541([Symbol name=reshape])" + "539([Symbol name=split])" -> "541([Symbol name=reshape])" + "542([Symbol name=reshape])" + "539([Symbol name=split])" -> "542([Symbol name=reshape])" + "585([Symbol name=convert_element_type])" + "584([Symbol name=convert_element_type])" -> "585([Symbol name=convert_element_type])" + "618([Symbol name=convert_element_type])" + "584([Symbol name=convert_element_type])" -> "618([Symbol name=convert_element_type])" + "612([Symbol name=convert_element_type])" + "611([Symbol name=convert_element_type])" -> "612([Symbol name=convert_element_type])" + "605([Symbol name=add])" + "604([Symbol name=exp])" -> "605([Symbol name=add])" + "682([Symbol name=convert_element_type])" + "620([Symbol name=convert_element_type])" -> "682([Symbol name=convert_element_type])" + "621([Symbol name=convert_element_type])" + "620([Symbol name=convert_element_type])" -> "621([Symbol name=convert_element_type])" + "640([Symbol name=reshape])" + "639([Symbol name=split])" -> "640([Symbol name=reshape])" + "641([Symbol name=reshape])" + "639([Symbol name=split])" -> "641([Symbol name=reshape])" + "642([Symbol name=reshape])" + "639([Symbol name=split])" -> "642([Symbol name=reshape])" + "685([Symbol name=convert_element_type])" + "684([Symbol name=convert_element_type])" -> "685([Symbol name=convert_element_type])" + "718([Symbol name=convert_element_type])" + "684([Symbol name=convert_element_type])" -> "718([Symbol name=convert_element_type])" + "712([Symbol name=convert_element_type])" + "711([Symbol name=convert_element_type])" -> "712([Symbol name=convert_element_type])" + "705([Symbol name=add])" + "704([Symbol name=exp])" -> "705([Symbol name=add])" + "721([Symbol name=convert_element_type])" + "720([Symbol name=convert_element_type])" -> "721([Symbol name=convert_element_type])" + "782([Symbol name=convert_element_type])" + "720([Symbol name=convert_element_type])" -> "782([Symbol name=convert_element_type])" + "740([Symbol name=reshape])" + "739([Symbol name=split])" -> "740([Symbol name=reshape])" + "741([Symbol name=reshape])" + "739([Symbol name=split])" -> "741([Symbol name=reshape])" + "742([Symbol name=reshape])" + "739([Symbol name=split])" -> "742([Symbol name=reshape])" + "785([Symbol name=convert_element_type])" + "784([Symbol name=convert_element_type])" -> "785([Symbol name=convert_element_type])" + "818([Symbol name=convert_element_type])" + "784([Symbol name=convert_element_type])" -> "818([Symbol name=convert_element_type])" + "812([Symbol name=convert_element_type])" + "811([Symbol name=convert_element_type])" -> "812([Symbol name=convert_element_type])" + "805([Symbol name=add])" + "804([Symbol name=exp])" -> "805([Symbol name=add])" + "882([Symbol name=convert_element_type])" + "820([Symbol name=convert_element_type])" -> "882([Symbol name=convert_element_type])" + "821([Symbol name=convert_element_type])" + "820([Symbol name=convert_element_type])" -> "821([Symbol name=convert_element_type])" + "840([Symbol name=reshape])" + "839([Symbol name=split])" -> "840([Symbol name=reshape])" + "841([Symbol name=reshape])" + "839([Symbol name=split])" -> "841([Symbol name=reshape])" + "842([Symbol name=reshape])" + "839([Symbol name=split])" -> "842([Symbol name=reshape])" + "885([Symbol name=convert_element_type])" + "884([Symbol name=convert_element_type])" -> "885([Symbol name=convert_element_type])" + "918([Symbol name=convert_element_type])" + "884([Symbol name=convert_element_type])" -> "918([Symbol name=convert_element_type])" + "912([Symbol name=convert_element_type])" + "911([Symbol name=convert_element_type])" -> "912([Symbol name=convert_element_type])" + "905([Symbol name=add])" + "904([Symbol name=exp])" -> "905([Symbol name=add])" + "921([Symbol name=convert_element_type])" + "920([Symbol name=convert_element_type])" -> "921([Symbol name=convert_element_type])" + "982([Symbol name=convert_element_type])" + "920([Symbol name=convert_element_type])" -> "982([Symbol name=convert_element_type])" + "940([Symbol name=reshape])" + "939([Symbol name=split])" -> "940([Symbol name=reshape])" + "941([Symbol name=reshape])" + "939([Symbol name=split])" -> "941([Symbol name=reshape])" + "942([Symbol name=reshape])" + "939([Symbol name=split])" -> "942([Symbol name=reshape])" + "985([Symbol name=convert_element_type])" + "984([Symbol name=convert_element_type])" -> "985([Symbol name=convert_element_type])" + "1018([Symbol name=convert_element_type])" + "984([Symbol name=convert_element_type])" -> "1018([Symbol name=convert_element_type])" + "1012([Symbol name=convert_element_type])" + "1011([Symbol name=convert_element_type])" -> "1012([Symbol name=convert_element_type])" + "1005([Symbol name=add])" + "1004([Symbol name=exp])" -> "1005([Symbol name=add])" + "1082([Symbol name=convert_element_type])" + "1020([Symbol name=convert_element_type])" -> "1082([Symbol name=convert_element_type])" + "1021([Symbol name=convert_element_type])" + "1020([Symbol name=convert_element_type])" -> "1021([Symbol name=convert_element_type])" + "1040([Symbol name=reshape])" + "1039([Symbol name=split])" -> "1040([Symbol name=reshape])" + "1041([Symbol name=reshape])" + "1039([Symbol name=split])" -> "1041([Symbol name=reshape])" + "1042([Symbol name=reshape])" + "1039([Symbol name=split])" -> "1042([Symbol name=reshape])" + "1085([Symbol name=convert_element_type])" + "1084([Symbol name=convert_element_type])" -> "1085([Symbol name=convert_element_type])" + "1118([Symbol name=convert_element_type])" + "1084([Symbol name=convert_element_type])" -> "1118([Symbol name=convert_element_type])" + "1112([Symbol name=convert_element_type])" + "1111([Symbol name=convert_element_type])" -> "1112([Symbol name=convert_element_type])" + "1105([Symbol name=add])" + "1104([Symbol name=exp])" -> "1105([Symbol name=add])" + "1121([Symbol name=convert_element_type])" + "1120([Symbol name=convert_element_type])" -> "1121([Symbol name=convert_element_type])" + "1182([Symbol name=convert_element_type])" + "1120([Symbol name=convert_element_type])" -> "1182([Symbol name=convert_element_type])" + "1140([Symbol name=reshape])" + "1139([Symbol name=split])" -> "1140([Symbol name=reshape])" + "1141([Symbol name=reshape])" + "1139([Symbol name=split])" -> "1141([Symbol name=reshape])" + "1142([Symbol name=reshape])" + "1139([Symbol name=split])" -> "1142([Symbol name=reshape])" + "1185([Symbol name=convert_element_type])" + "1184([Symbol name=convert_element_type])" -> "1185([Symbol name=convert_element_type])" + "1218([Symbol name=convert_element_type])" + "1184([Symbol name=convert_element_type])" -> "1218([Symbol name=convert_element_type])" + "1212([Symbol name=convert_element_type])" + "1211([Symbol name=convert_element_type])" -> "1212([Symbol name=convert_element_type])" + "1205([Symbol name=add])" + "1204([Symbol name=exp])" -> "1205([Symbol name=add])" + "1282([Symbol name=convert_element_type])" + "1220([Symbol name=convert_element_type])" -> "1282([Symbol name=convert_element_type])" + "1221([Symbol name=convert_element_type])" + "1220([Symbol name=convert_element_type])" -> "1221([Symbol name=convert_element_type])" + "1240([Symbol name=reshape])" + "1239([Symbol name=split])" -> "1240([Symbol name=reshape])" + "1241([Symbol name=reshape])" + "1239([Symbol name=split])" -> "1241([Symbol name=reshape])" + "1242([Symbol name=reshape])" + "1239([Symbol name=split])" -> "1242([Symbol name=reshape])" + "1285([Symbol name=convert_element_type])" + "1284([Symbol name=convert_element_type])" -> "1285([Symbol name=convert_element_type])" + "1318([Symbol name=convert_element_type])" + "1284([Symbol name=convert_element_type])" -> "1318([Symbol name=convert_element_type])" + "1312([Symbol name=convert_element_type])" + "1311([Symbol name=convert_element_type])" -> "1312([Symbol name=convert_element_type])" + "1305([Symbol name=add])" + "1304([Symbol name=exp])" -> "1305([Symbol name=add])" + "1321([Symbol name=convert_element_type])" + "1320([Symbol name=convert_element_type])" -> "1321([Symbol name=convert_element_type])" + "1382([Symbol name=convert_element_type])" + "1320([Symbol name=convert_element_type])" -> "1382([Symbol name=convert_element_type])" + "1340([Symbol name=reshape])" + "1339([Symbol name=split])" -> "1340([Symbol name=reshape])" + "1341([Symbol name=reshape])" + "1339([Symbol name=split])" -> "1341([Symbol name=reshape])" + "1342([Symbol name=reshape])" + "1339([Symbol name=split])" -> "1342([Symbol name=reshape])" + "1385([Symbol name=convert_element_type])" + "1384([Symbol name=convert_element_type])" -> "1385([Symbol name=convert_element_type])" + "1418([Symbol name=convert_element_type])" + "1384([Symbol name=convert_element_type])" -> "1418([Symbol name=convert_element_type])" + "1412([Symbol name=convert_element_type])" + "1411([Symbol name=convert_element_type])" -> "1412([Symbol name=convert_element_type])" + "1405([Symbol name=add])" + "1404([Symbol name=exp])" -> "1405([Symbol name=add])" + "1482([Symbol name=convert_element_type])" + "1420([Symbol name=convert_element_type])" -> "1482([Symbol name=convert_element_type])" + "1421([Symbol name=convert_element_type])" + "1420([Symbol name=convert_element_type])" -> "1421([Symbol name=convert_element_type])" + "1440([Symbol name=reshape])" + "1439([Symbol name=split])" -> "1440([Symbol name=reshape])" + "1441([Symbol name=reshape])" + "1439([Symbol name=split])" -> "1441([Symbol name=reshape])" + "1442([Symbol name=reshape])" + "1439([Symbol name=split])" -> "1442([Symbol name=reshape])" + "1485([Symbol name=convert_element_type])" + "1484([Symbol name=convert_element_type])" -> "1485([Symbol name=convert_element_type])" + "1518([Symbol name=convert_element_type])" + "1484([Symbol name=convert_element_type])" -> "1518([Symbol name=convert_element_type])" + "1512([Symbol name=convert_element_type])" + "1511([Symbol name=convert_element_type])" -> "1512([Symbol name=convert_element_type])" + "1505([Symbol name=add])" + "1504([Symbol name=exp])" -> "1505([Symbol name=add])" + "1521([Symbol name=convert_element_type])" + "1520([Symbol name=convert_element_type])" -> "1521([Symbol name=convert_element_type])" + "1582([Symbol name=convert_element_type])" + "1520([Symbol name=convert_element_type])" -> "1582([Symbol name=convert_element_type])" + "1540([Symbol name=reshape])" + "1539([Symbol name=split])" -> "1540([Symbol name=reshape])" + "1541([Symbol name=reshape])" + "1539([Symbol name=split])" -> "1541([Symbol name=reshape])" + "1542([Symbol name=reshape])" + "1539([Symbol name=split])" -> "1542([Symbol name=reshape])" + "1585([Symbol name=convert_element_type])" + "1584([Symbol name=convert_element_type])" -> "1585([Symbol name=convert_element_type])" + "1618([Symbol name=convert_element_type])" + "1584([Symbol name=convert_element_type])" -> "1618([Symbol name=convert_element_type])" + "1612([Symbol name=convert_element_type])" + "1611([Symbol name=convert_element_type])" -> "1612([Symbol name=convert_element_type])" + "1605([Symbol name=add])" + "1604([Symbol name=exp])" -> "1605([Symbol name=add])" + "1682([Symbol name=convert_element_type])" + "1620([Symbol name=convert_element_type])" -> "1682([Symbol name=convert_element_type])" + "1621([Symbol name=convert_element_type])" + "1620([Symbol name=convert_element_type])" -> "1621([Symbol name=convert_element_type])" + "1640([Symbol name=reshape])" + "1639([Symbol name=split])" -> "1640([Symbol name=reshape])" + "1641([Symbol name=reshape])" + "1639([Symbol name=split])" -> "1641([Symbol name=reshape])" + "1642([Symbol name=reshape])" + "1639([Symbol name=split])" -> "1642([Symbol name=reshape])" + "1685([Symbol name=convert_element_type])" + "1684([Symbol name=convert_element_type])" -> "1685([Symbol name=convert_element_type])" + "1718([Symbol name=convert_element_type])" + "1684([Symbol name=convert_element_type])" -> "1718([Symbol name=convert_element_type])" + "1712([Symbol name=convert_element_type])" + "1711([Symbol name=convert_element_type])" -> "1712([Symbol name=convert_element_type])" + "1705([Symbol name=add])" + "1704([Symbol name=exp])" -> "1705([Symbol name=add])" + "1721([Symbol name=convert_element_type])" + "1720([Symbol name=convert_element_type])" -> "1721([Symbol name=convert_element_type])" + "125([Symbol name=true_divide])" + "124([Symbol name=broadcast_in_dim])" -> "125([Symbol name=true_divide])" + "193([Symbol name=mul])" + "192([Symbol name=broadcast_in_dim])" -> "193([Symbol name=mul])" + "185([Symbol name=convert_element_type])" -> "193([Symbol name=mul])" + "186([Symbol name=mul])" + "185([Symbol name=convert_element_type])" -> "186([Symbol name=mul])" + "1676([Symbol name=cat])" + "1672([Symbol name=convert_element_type])" -> "1676([Symbol name=cat])" + "1675([Symbol name=slice_prim])" -> "1676([Symbol name=cat])" + "276([Symbol name=cat])" + "272([Symbol name=convert_element_type])" -> "276([Symbol name=cat])" + "275([Symbol name=slice_prim])" -> "276([Symbol name=cat])" + "674([Symbol name=cat])" + "657([Symbol name=convert_element_type])" -> "674([Symbol name=cat])" + "673([Symbol name=slice_prim])" -> "674([Symbol name=cat])" + "1176([Symbol name=cat])" + "1172([Symbol name=convert_element_type])" -> "1176([Symbol name=cat])" + "1175([Symbol name=slice_prim])" -> "1176([Symbol name=cat])" + "1574([Symbol name=cat])" + "1573([Symbol name=slice_prim])" -> "1574([Symbol name=cat])" + "1557([Symbol name=convert_element_type])" -> "1574([Symbol name=cat])" + "174([Symbol name=cat])" + "157([Symbol name=convert_element_type])" -> "174([Symbol name=cat])" + "173([Symbol name=slice_prim])" -> "174([Symbol name=cat])" + "676([Symbol name=cat])" + "672([Symbol name=convert_element_type])" -> "676([Symbol name=cat])" + "675([Symbol name=slice_prim])" -> "676([Symbol name=cat])" + "1074([Symbol name=cat])" + "1057([Symbol name=convert_element_type])" -> "1074([Symbol name=cat])" + "1073([Symbol name=slice_prim])" -> "1074([Symbol name=cat])" + "1576([Symbol name=cat])" + "1572([Symbol name=convert_element_type])" -> "1576([Symbol name=cat])" + "1575([Symbol name=slice_prim])" -> "1576([Symbol name=cat])" + "176([Symbol name=cat])" + "172([Symbol name=convert_element_type])" -> "176([Symbol name=cat])" + "175([Symbol name=slice_prim])" -> "176([Symbol name=cat])" + "574([Symbol name=cat])" + "573([Symbol name=slice_prim])" -> "574([Symbol name=cat])" + "557([Symbol name=convert_element_type])" -> "574([Symbol name=cat])" + "1076([Symbol name=cat])" + "1072([Symbol name=convert_element_type])" -> "1076([Symbol name=cat])" + "1075([Symbol name=slice_prim])" -> "1076([Symbol name=cat])" + "1474([Symbol name=cat])" + "1457([Symbol name=convert_element_type])" -> "1474([Symbol name=cat])" + "1473([Symbol name=slice_prim])" -> "1474([Symbol name=cat])" + "576([Symbol name=cat])" + "572([Symbol name=convert_element_type])" -> "576([Symbol name=cat])" + "575([Symbol name=slice_prim])" -> "576([Symbol name=cat])" + "974([Symbol name=cat])" + "973([Symbol name=slice_prim])" -> "974([Symbol name=cat])" + "957([Symbol name=convert_element_type])" -> "974([Symbol name=cat])" + "1476([Symbol name=cat])" + "1472([Symbol name=convert_element_type])" -> "1476([Symbol name=cat])" + "1475([Symbol name=slice_prim])" -> "1476([Symbol name=cat])" + "474([Symbol name=cat])" + "457([Symbol name=convert_element_type])" -> "474([Symbol name=cat])" + "473([Symbol name=slice_prim])" -> "474([Symbol name=cat])" + "976([Symbol name=cat])" + "972([Symbol name=convert_element_type])" -> "976([Symbol name=cat])" + "975([Symbol name=slice_prim])" -> "976([Symbol name=cat])" + "1374([Symbol name=cat])" + "1373([Symbol name=slice_prim])" -> "1374([Symbol name=cat])" + "1357([Symbol name=convert_element_type])" -> "1374([Symbol name=cat])" + "476([Symbol name=cat])" + "472([Symbol name=convert_element_type])" -> "476([Symbol name=cat])" + "475([Symbol name=slice_prim])" -> "476([Symbol name=cat])" + "874([Symbol name=cat])" + "857([Symbol name=convert_element_type])" -> "874([Symbol name=cat])" + "873([Symbol name=slice_prim])" -> "874([Symbol name=cat])" + "1376([Symbol name=cat])" + "1372([Symbol name=convert_element_type])" -> "1376([Symbol name=cat])" + "1375([Symbol name=slice_prim])" -> "1376([Symbol name=cat])" + "374([Symbol name=cat])" + "373([Symbol name=slice_prim])" -> "374([Symbol name=cat])" + "357([Symbol name=convert_element_type])" -> "374([Symbol name=cat])" + "876([Symbol name=cat])" + "872([Symbol name=convert_element_type])" -> "876([Symbol name=cat])" + "875([Symbol name=slice_prim])" -> "876([Symbol name=cat])" + "1274([Symbol name=cat])" + "1257([Symbol name=convert_element_type])" -> "1274([Symbol name=cat])" + "1273([Symbol name=slice_prim])" -> "1274([Symbol name=cat])" + "376([Symbol name=cat])" + "372([Symbol name=convert_element_type])" -> "376([Symbol name=cat])" + "375([Symbol name=slice_prim])" -> "376([Symbol name=cat])" + "774([Symbol name=cat])" + "773([Symbol name=slice_prim])" -> "774([Symbol name=cat])" + "757([Symbol name=convert_element_type])" -> "774([Symbol name=cat])" + "1276([Symbol name=cat])" + "1272([Symbol name=convert_element_type])" -> "1276([Symbol name=cat])" + "1275([Symbol name=slice_prim])" -> "1276([Symbol name=cat])" + "1674([Symbol name=cat])" + "1657([Symbol name=convert_element_type])" -> "1674([Symbol name=cat])" + "1673([Symbol name=slice_prim])" -> "1674([Symbol name=cat])" + "274([Symbol name=cat])" + "257([Symbol name=convert_element_type])" -> "274([Symbol name=cat])" + "273([Symbol name=slice_prim])" -> "274([Symbol name=cat])" + "776([Symbol name=cat])" + "772([Symbol name=convert_element_type])" -> "776([Symbol name=cat])" + "775([Symbol name=slice_prim])" -> "776([Symbol name=cat])" + "1174([Symbol name=cat])" + "1157([Symbol name=convert_element_type])" -> "1174([Symbol name=cat])" + "1173([Symbol name=slice_prim])" -> "1174([Symbol name=cat])" + "173([Symbol name=slice_prim])" + "140([Symbol name=reshape])" -> "173([Symbol name=slice_prim])" + "143([Symbol name=slice_prim])" + "140([Symbol name=reshape])" -> "143([Symbol name=slice_prim])" + "158([Symbol name=slice_prim])" + "141([Symbol name=reshape])" -> "158([Symbol name=slice_prim])" + "175([Symbol name=slice_prim])" + "141([Symbol name=reshape])" -> "175([Symbol name=slice_prim])" + "177([Symbol name=cudnn_sdpa_fwd])" + "176([Symbol name=cat])" -> "177([Symbol name=cudnn_sdpa_fwd])" + "142([Symbol name=reshape])" -> "177([Symbol name=cudnn_sdpa_fwd])" + "174([Symbol name=cat])" -> "177([Symbol name=cudnn_sdpa_fwd])" + "206([Symbol name=reciprocal])" + "205([Symbol name=add])" -> "206([Symbol name=reciprocal])" + "229([Symbol name=mul])" + "228([Symbol name=broadcast_in_dim])" -> "229([Symbol name=mul])" + "221([Symbol name=convert_element_type])" -> "229([Symbol name=mul])" + "222([Symbol name=mul])" + "221([Symbol name=convert_element_type])" -> "222([Symbol name=mul])" + "273([Symbol name=slice_prim])" + "240([Symbol name=reshape])" -> "273([Symbol name=slice_prim])" + "243([Symbol name=slice_prim])" + "240([Symbol name=reshape])" -> "243([Symbol name=slice_prim])" + "258([Symbol name=slice_prim])" + "241([Symbol name=reshape])" -> "258([Symbol name=slice_prim])" + "275([Symbol name=slice_prim])" + "241([Symbol name=reshape])" -> "275([Symbol name=slice_prim])" + "277([Symbol name=cudnn_sdpa_fwd])" + "274([Symbol name=cat])" -> "277([Symbol name=cudnn_sdpa_fwd])" + "242([Symbol name=reshape])" -> "277([Symbol name=cudnn_sdpa_fwd])" + "276([Symbol name=cat])" -> "277([Symbol name=cudnn_sdpa_fwd])" + "293([Symbol name=mul])" + "292([Symbol name=broadcast_in_dim])" -> "293([Symbol name=mul])" + "285([Symbol name=convert_element_type])" -> "293([Symbol name=mul])" + "286([Symbol name=mul])" + "285([Symbol name=convert_element_type])" -> "286([Symbol name=mul])" + "306([Symbol name=reciprocal])" + "305([Symbol name=add])" -> "306([Symbol name=reciprocal])" + "329([Symbol name=mul])" + "328([Symbol name=broadcast_in_dim])" -> "329([Symbol name=mul])" + "321([Symbol name=convert_element_type])" -> "329([Symbol name=mul])" + "322([Symbol name=mul])" + "321([Symbol name=convert_element_type])" -> "322([Symbol name=mul])" + "373([Symbol name=slice_prim])" + "340([Symbol name=reshape])" -> "373([Symbol name=slice_prim])" + "343([Symbol name=slice_prim])" + "340([Symbol name=reshape])" -> "343([Symbol name=slice_prim])" + "358([Symbol name=slice_prim])" + "341([Symbol name=reshape])" -> "358([Symbol name=slice_prim])" + "375([Symbol name=slice_prim])" + "341([Symbol name=reshape])" -> "375([Symbol name=slice_prim])" + "377([Symbol name=cudnn_sdpa_fwd])" + "376([Symbol name=cat])" -> "377([Symbol name=cudnn_sdpa_fwd])" + "342([Symbol name=reshape])" -> "377([Symbol name=cudnn_sdpa_fwd])" + "374([Symbol name=cat])" -> "377([Symbol name=cudnn_sdpa_fwd])" + "393([Symbol name=mul])" + "392([Symbol name=broadcast_in_dim])" -> "393([Symbol name=mul])" + "385([Symbol name=convert_element_type])" -> "393([Symbol name=mul])" + "386([Symbol name=mul])" + "385([Symbol name=convert_element_type])" -> "386([Symbol name=mul])" + "406([Symbol name=reciprocal])" + "405([Symbol name=add])" -> "406([Symbol name=reciprocal])" + "429([Symbol name=mul])" + "428([Symbol name=broadcast_in_dim])" -> "429([Symbol name=mul])" + "421([Symbol name=convert_element_type])" -> "429([Symbol name=mul])" + "422([Symbol name=mul])" + "421([Symbol name=convert_element_type])" -> "422([Symbol name=mul])" + "473([Symbol name=slice_prim])" + "440([Symbol name=reshape])" -> "473([Symbol name=slice_prim])" + "443([Symbol name=slice_prim])" + "440([Symbol name=reshape])" -> "443([Symbol name=slice_prim])" + "458([Symbol name=slice_prim])" + "441([Symbol name=reshape])" -> "458([Symbol name=slice_prim])" + "475([Symbol name=slice_prim])" + "441([Symbol name=reshape])" -> "475([Symbol name=slice_prim])" + "477([Symbol name=cudnn_sdpa_fwd])" + "442([Symbol name=reshape])" -> "477([Symbol name=cudnn_sdpa_fwd])" + "474([Symbol name=cat])" -> "477([Symbol name=cudnn_sdpa_fwd])" + "476([Symbol name=cat])" -> "477([Symbol name=cudnn_sdpa_fwd])" + "493([Symbol name=mul])" + "492([Symbol name=broadcast_in_dim])" -> "493([Symbol name=mul])" + "485([Symbol name=convert_element_type])" -> "493([Symbol name=mul])" + "486([Symbol name=mul])" + "485([Symbol name=convert_element_type])" -> "486([Symbol name=mul])" + "506([Symbol name=reciprocal])" + "505([Symbol name=add])" -> "506([Symbol name=reciprocal])" + "529([Symbol name=mul])" + "528([Symbol name=broadcast_in_dim])" -> "529([Symbol name=mul])" + "521([Symbol name=convert_element_type])" -> "529([Symbol name=mul])" + "522([Symbol name=mul])" + "521([Symbol name=convert_element_type])" -> "522([Symbol name=mul])" + "573([Symbol name=slice_prim])" + "540([Symbol name=reshape])" -> "573([Symbol name=slice_prim])" + "543([Symbol name=slice_prim])" + "540([Symbol name=reshape])" -> "543([Symbol name=slice_prim])" + "558([Symbol name=slice_prim])" + "541([Symbol name=reshape])" -> "558([Symbol name=slice_prim])" + "575([Symbol name=slice_prim])" + "541([Symbol name=reshape])" -> "575([Symbol name=slice_prim])" + "577([Symbol name=cudnn_sdpa_fwd])" + "576([Symbol name=cat])" -> "577([Symbol name=cudnn_sdpa_fwd])" + "574([Symbol name=cat])" -> "577([Symbol name=cudnn_sdpa_fwd])" + "542([Symbol name=reshape])" -> "577([Symbol name=cudnn_sdpa_fwd])" + "593([Symbol name=mul])" + "592([Symbol name=broadcast_in_dim])" -> "593([Symbol name=mul])" + "585([Symbol name=convert_element_type])" -> "593([Symbol name=mul])" + "586([Symbol name=mul])" + "585([Symbol name=convert_element_type])" -> "586([Symbol name=mul])" + "606([Symbol name=reciprocal])" + "605([Symbol name=add])" -> "606([Symbol name=reciprocal])" + "629([Symbol name=mul])" + "628([Symbol name=broadcast_in_dim])" -> "629([Symbol name=mul])" + "621([Symbol name=convert_element_type])" -> "629([Symbol name=mul])" + "622([Symbol name=mul])" + "621([Symbol name=convert_element_type])" -> "622([Symbol name=mul])" + "673([Symbol name=slice_prim])" + "640([Symbol name=reshape])" -> "673([Symbol name=slice_prim])" + "643([Symbol name=slice_prim])" + "640([Symbol name=reshape])" -> "643([Symbol name=slice_prim])" + "658([Symbol name=slice_prim])" + "641([Symbol name=reshape])" -> "658([Symbol name=slice_prim])" + "675([Symbol name=slice_prim])" + "641([Symbol name=reshape])" -> "675([Symbol name=slice_prim])" + "677([Symbol name=cudnn_sdpa_fwd])" + "674([Symbol name=cat])" -> "677([Symbol name=cudnn_sdpa_fwd])" + "676([Symbol name=cat])" -> "677([Symbol name=cudnn_sdpa_fwd])" + "642([Symbol name=reshape])" -> "677([Symbol name=cudnn_sdpa_fwd])" + "693([Symbol name=mul])" + "692([Symbol name=broadcast_in_dim])" -> "693([Symbol name=mul])" + "685([Symbol name=convert_element_type])" -> "693([Symbol name=mul])" + "686([Symbol name=mul])" + "685([Symbol name=convert_element_type])" -> "686([Symbol name=mul])" + "706([Symbol name=reciprocal])" + "705([Symbol name=add])" -> "706([Symbol name=reciprocal])" + "729([Symbol name=mul])" + "728([Symbol name=broadcast_in_dim])" -> "729([Symbol name=mul])" + "721([Symbol name=convert_element_type])" -> "729([Symbol name=mul])" + "722([Symbol name=mul])" + "721([Symbol name=convert_element_type])" -> "722([Symbol name=mul])" + "773([Symbol name=slice_prim])" + "740([Symbol name=reshape])" -> "773([Symbol name=slice_prim])" + "743([Symbol name=slice_prim])" + "740([Symbol name=reshape])" -> "743([Symbol name=slice_prim])" + "758([Symbol name=slice_prim])" + "741([Symbol name=reshape])" -> "758([Symbol name=slice_prim])" + "775([Symbol name=slice_prim])" + "741([Symbol name=reshape])" -> "775([Symbol name=slice_prim])" + "777([Symbol name=cudnn_sdpa_fwd])" + "776([Symbol name=cat])" -> "777([Symbol name=cudnn_sdpa_fwd])" + "774([Symbol name=cat])" -> "777([Symbol name=cudnn_sdpa_fwd])" + "742([Symbol name=reshape])" -> "777([Symbol name=cudnn_sdpa_fwd])" + "793([Symbol name=mul])" + "792([Symbol name=broadcast_in_dim])" -> "793([Symbol name=mul])" + "785([Symbol name=convert_element_type])" -> "793([Symbol name=mul])" + "786([Symbol name=mul])" + "785([Symbol name=convert_element_type])" -> "786([Symbol name=mul])" + "806([Symbol name=reciprocal])" + "805([Symbol name=add])" -> "806([Symbol name=reciprocal])" + "829([Symbol name=mul])" + "828([Symbol name=broadcast_in_dim])" -> "829([Symbol name=mul])" + "821([Symbol name=convert_element_type])" -> "829([Symbol name=mul])" + "822([Symbol name=mul])" + "821([Symbol name=convert_element_type])" -> "822([Symbol name=mul])" + "873([Symbol name=slice_prim])" + "840([Symbol name=reshape])" -> "873([Symbol name=slice_prim])" + "843([Symbol name=slice_prim])" + "840([Symbol name=reshape])" -> "843([Symbol name=slice_prim])" + "858([Symbol name=slice_prim])" + "841([Symbol name=reshape])" -> "858([Symbol name=slice_prim])" + "875([Symbol name=slice_prim])" + "841([Symbol name=reshape])" -> "875([Symbol name=slice_prim])" + "877([Symbol name=cudnn_sdpa_fwd])" + "874([Symbol name=cat])" -> "877([Symbol name=cudnn_sdpa_fwd])" + "876([Symbol name=cat])" -> "877([Symbol name=cudnn_sdpa_fwd])" + "842([Symbol name=reshape])" -> "877([Symbol name=cudnn_sdpa_fwd])" + "893([Symbol name=mul])" + "892([Symbol name=broadcast_in_dim])" -> "893([Symbol name=mul])" + "885([Symbol name=convert_element_type])" -> "893([Symbol name=mul])" + "886([Symbol name=mul])" + "885([Symbol name=convert_element_type])" -> "886([Symbol name=mul])" + "906([Symbol name=reciprocal])" + "905([Symbol name=add])" -> "906([Symbol name=reciprocal])" + "929([Symbol name=mul])" + "928([Symbol name=broadcast_in_dim])" -> "929([Symbol name=mul])" + "921([Symbol name=convert_element_type])" -> "929([Symbol name=mul])" + "922([Symbol name=mul])" + "921([Symbol name=convert_element_type])" -> "922([Symbol name=mul])" + "973([Symbol name=slice_prim])" + "940([Symbol name=reshape])" -> "973([Symbol name=slice_prim])" + "943([Symbol name=slice_prim])" + "940([Symbol name=reshape])" -> "943([Symbol name=slice_prim])" + "958([Symbol name=slice_prim])" + "941([Symbol name=reshape])" -> "958([Symbol name=slice_prim])" + "975([Symbol name=slice_prim])" + "941([Symbol name=reshape])" -> "975([Symbol name=slice_prim])" + "977([Symbol name=cudnn_sdpa_fwd])" + "976([Symbol name=cat])" -> "977([Symbol name=cudnn_sdpa_fwd])" + "942([Symbol name=reshape])" -> "977([Symbol name=cudnn_sdpa_fwd])" + "974([Symbol name=cat])" -> "977([Symbol name=cudnn_sdpa_fwd])" + "993([Symbol name=mul])" + "992([Symbol name=broadcast_in_dim])" -> "993([Symbol name=mul])" + "985([Symbol name=convert_element_type])" -> "993([Symbol name=mul])" + "986([Symbol name=mul])" + "985([Symbol name=convert_element_type])" -> "986([Symbol name=mul])" + "1006([Symbol name=reciprocal])" + "1005([Symbol name=add])" -> "1006([Symbol name=reciprocal])" + "1029([Symbol name=mul])" + "1028([Symbol name=broadcast_in_dim])" -> "1029([Symbol name=mul])" + "1021([Symbol name=convert_element_type])" -> "1029([Symbol name=mul])" + "1022([Symbol name=mul])" + "1021([Symbol name=convert_element_type])" -> "1022([Symbol name=mul])" + "1073([Symbol name=slice_prim])" + "1040([Symbol name=reshape])" -> "1073([Symbol name=slice_prim])" + "1043([Symbol name=slice_prim])" + "1040([Symbol name=reshape])" -> "1043([Symbol name=slice_prim])" + "1058([Symbol name=slice_prim])" + "1041([Symbol name=reshape])" -> "1058([Symbol name=slice_prim])" + "1075([Symbol name=slice_prim])" + "1041([Symbol name=reshape])" -> "1075([Symbol name=slice_prim])" + "1077([Symbol name=cudnn_sdpa_fwd])" + "1074([Symbol name=cat])" -> "1077([Symbol name=cudnn_sdpa_fwd])" + "1042([Symbol name=reshape])" -> "1077([Symbol name=cudnn_sdpa_fwd])" + "1076([Symbol name=cat])" -> "1077([Symbol name=cudnn_sdpa_fwd])" + "1093([Symbol name=mul])" + "1092([Symbol name=broadcast_in_dim])" -> "1093([Symbol name=mul])" + "1085([Symbol name=convert_element_type])" -> "1093([Symbol name=mul])" + "1086([Symbol name=mul])" + "1085([Symbol name=convert_element_type])" -> "1086([Symbol name=mul])" + "1106([Symbol name=reciprocal])" + "1105([Symbol name=add])" -> "1106([Symbol name=reciprocal])" + "1129([Symbol name=mul])" + "1128([Symbol name=broadcast_in_dim])" -> "1129([Symbol name=mul])" + "1121([Symbol name=convert_element_type])" -> "1129([Symbol name=mul])" + "1122([Symbol name=mul])" + "1121([Symbol name=convert_element_type])" -> "1122([Symbol name=mul])" + "1173([Symbol name=slice_prim])" + "1140([Symbol name=reshape])" -> "1173([Symbol name=slice_prim])" + "1143([Symbol name=slice_prim])" + "1140([Symbol name=reshape])" -> "1143([Symbol name=slice_prim])" + "1158([Symbol name=slice_prim])" + "1141([Symbol name=reshape])" -> "1158([Symbol name=slice_prim])" + "1175([Symbol name=slice_prim])" + "1141([Symbol name=reshape])" -> "1175([Symbol name=slice_prim])" + "1177([Symbol name=cudnn_sdpa_fwd])" + "1176([Symbol name=cat])" -> "1177([Symbol name=cudnn_sdpa_fwd])" + "1142([Symbol name=reshape])" -> "1177([Symbol name=cudnn_sdpa_fwd])" + "1174([Symbol name=cat])" -> "1177([Symbol name=cudnn_sdpa_fwd])" + "1193([Symbol name=mul])" + "1192([Symbol name=broadcast_in_dim])" -> "1193([Symbol name=mul])" + "1185([Symbol name=convert_element_type])" -> "1193([Symbol name=mul])" + "1186([Symbol name=mul])" + "1185([Symbol name=convert_element_type])" -> "1186([Symbol name=mul])" + "1206([Symbol name=reciprocal])" + "1205([Symbol name=add])" -> "1206([Symbol name=reciprocal])" + "1229([Symbol name=mul])" + "1228([Symbol name=broadcast_in_dim])" -> "1229([Symbol name=mul])" + "1221([Symbol name=convert_element_type])" -> "1229([Symbol name=mul])" + "1222([Symbol name=mul])" + "1221([Symbol name=convert_element_type])" -> "1222([Symbol name=mul])" + "1273([Symbol name=slice_prim])" + "1240([Symbol name=reshape])" -> "1273([Symbol name=slice_prim])" + "1243([Symbol name=slice_prim])" + "1240([Symbol name=reshape])" -> "1243([Symbol name=slice_prim])" + "1258([Symbol name=slice_prim])" + "1241([Symbol name=reshape])" -> "1258([Symbol name=slice_prim])" + "1275([Symbol name=slice_prim])" + "1241([Symbol name=reshape])" -> "1275([Symbol name=slice_prim])" + "1277([Symbol name=cudnn_sdpa_fwd])" + "1242([Symbol name=reshape])" -> "1277([Symbol name=cudnn_sdpa_fwd])" + "1274([Symbol name=cat])" -> "1277([Symbol name=cudnn_sdpa_fwd])" + "1276([Symbol name=cat])" -> "1277([Symbol name=cudnn_sdpa_fwd])" + "1293([Symbol name=mul])" + "1292([Symbol name=broadcast_in_dim])" -> "1293([Symbol name=mul])" + "1285([Symbol name=convert_element_type])" -> "1293([Symbol name=mul])" + "1286([Symbol name=mul])" + "1285([Symbol name=convert_element_type])" -> "1286([Symbol name=mul])" + "1306([Symbol name=reciprocal])" + "1305([Symbol name=add])" -> "1306([Symbol name=reciprocal])" + "1329([Symbol name=mul])" + "1328([Symbol name=broadcast_in_dim])" -> "1329([Symbol name=mul])" + "1321([Symbol name=convert_element_type])" -> "1329([Symbol name=mul])" + "1322([Symbol name=mul])" + "1321([Symbol name=convert_element_type])" -> "1322([Symbol name=mul])" + "1373([Symbol name=slice_prim])" + "1340([Symbol name=reshape])" -> "1373([Symbol name=slice_prim])" + "1343([Symbol name=slice_prim])" + "1340([Symbol name=reshape])" -> "1343([Symbol name=slice_prim])" + "1358([Symbol name=slice_prim])" + "1341([Symbol name=reshape])" -> "1358([Symbol name=slice_prim])" + "1375([Symbol name=slice_prim])" + "1341([Symbol name=reshape])" -> "1375([Symbol name=slice_prim])" + "1377([Symbol name=cudnn_sdpa_fwd])" + "1376([Symbol name=cat])" -> "1377([Symbol name=cudnn_sdpa_fwd])" + "1342([Symbol name=reshape])" -> "1377([Symbol name=cudnn_sdpa_fwd])" + "1374([Symbol name=cat])" -> "1377([Symbol name=cudnn_sdpa_fwd])" + "1393([Symbol name=mul])" + "1392([Symbol name=broadcast_in_dim])" -> "1393([Symbol name=mul])" + "1385([Symbol name=convert_element_type])" -> "1393([Symbol name=mul])" + "1386([Symbol name=mul])" + "1385([Symbol name=convert_element_type])" -> "1386([Symbol name=mul])" + "1406([Symbol name=reciprocal])" + "1405([Symbol name=add])" -> "1406([Symbol name=reciprocal])" + "1429([Symbol name=mul])" + "1428([Symbol name=broadcast_in_dim])" -> "1429([Symbol name=mul])" + "1421([Symbol name=convert_element_type])" -> "1429([Symbol name=mul])" + "1422([Symbol name=mul])" + "1421([Symbol name=convert_element_type])" -> "1422([Symbol name=mul])" + "1473([Symbol name=slice_prim])" + "1440([Symbol name=reshape])" -> "1473([Symbol name=slice_prim])" + "1443([Symbol name=slice_prim])" + "1440([Symbol name=reshape])" -> "1443([Symbol name=slice_prim])" + "1458([Symbol name=slice_prim])" + "1441([Symbol name=reshape])" -> "1458([Symbol name=slice_prim])" + "1475([Symbol name=slice_prim])" + "1441([Symbol name=reshape])" -> "1475([Symbol name=slice_prim])" + "1477([Symbol name=cudnn_sdpa_fwd])" + "1442([Symbol name=reshape])" -> "1477([Symbol name=cudnn_sdpa_fwd])" + "1474([Symbol name=cat])" -> "1477([Symbol name=cudnn_sdpa_fwd])" + "1476([Symbol name=cat])" -> "1477([Symbol name=cudnn_sdpa_fwd])" + "1493([Symbol name=mul])" + "1492([Symbol name=broadcast_in_dim])" -> "1493([Symbol name=mul])" + "1485([Symbol name=convert_element_type])" -> "1493([Symbol name=mul])" + "1486([Symbol name=mul])" + "1485([Symbol name=convert_element_type])" -> "1486([Symbol name=mul])" + "1506([Symbol name=reciprocal])" + "1505([Symbol name=add])" -> "1506([Symbol name=reciprocal])" + "1529([Symbol name=mul])" + "1528([Symbol name=broadcast_in_dim])" -> "1529([Symbol name=mul])" + "1521([Symbol name=convert_element_type])" -> "1529([Symbol name=mul])" + "1522([Symbol name=mul])" + "1521([Symbol name=convert_element_type])" -> "1522([Symbol name=mul])" + "1573([Symbol name=slice_prim])" + "1540([Symbol name=reshape])" -> "1573([Symbol name=slice_prim])" + "1543([Symbol name=slice_prim])" + "1540([Symbol name=reshape])" -> "1543([Symbol name=slice_prim])" + "1558([Symbol name=slice_prim])" + "1541([Symbol name=reshape])" -> "1558([Symbol name=slice_prim])" + "1575([Symbol name=slice_prim])" + "1541([Symbol name=reshape])" -> "1575([Symbol name=slice_prim])" + "1577([Symbol name=cudnn_sdpa_fwd])" + "1576([Symbol name=cat])" -> "1577([Symbol name=cudnn_sdpa_fwd])" + "1574([Symbol name=cat])" -> "1577([Symbol name=cudnn_sdpa_fwd])" + "1542([Symbol name=reshape])" -> "1577([Symbol name=cudnn_sdpa_fwd])" + "1593([Symbol name=mul])" + "1592([Symbol name=broadcast_in_dim])" -> "1593([Symbol name=mul])" + "1585([Symbol name=convert_element_type])" -> "1593([Symbol name=mul])" + "1586([Symbol name=mul])" + "1585([Symbol name=convert_element_type])" -> "1586([Symbol name=mul])" + "1606([Symbol name=reciprocal])" + "1605([Symbol name=add])" -> "1606([Symbol name=reciprocal])" + "1629([Symbol name=mul])" + "1628([Symbol name=broadcast_in_dim])" -> "1629([Symbol name=mul])" + "1621([Symbol name=convert_element_type])" -> "1629([Symbol name=mul])" + "1622([Symbol name=mul])" + "1621([Symbol name=convert_element_type])" -> "1622([Symbol name=mul])" + "1673([Symbol name=slice_prim])" + "1640([Symbol name=reshape])" -> "1673([Symbol name=slice_prim])" + "1643([Symbol name=slice_prim])" + "1640([Symbol name=reshape])" -> "1643([Symbol name=slice_prim])" + "1658([Symbol name=slice_prim])" + "1641([Symbol name=reshape])" -> "1658([Symbol name=slice_prim])" + "1675([Symbol name=slice_prim])" + "1641([Symbol name=reshape])" -> "1675([Symbol name=slice_prim])" + "1677([Symbol name=cudnn_sdpa_fwd])" + "1674([Symbol name=cat])" -> "1677([Symbol name=cudnn_sdpa_fwd])" + "1676([Symbol name=cat])" -> "1677([Symbol name=cudnn_sdpa_fwd])" + "1642([Symbol name=reshape])" -> "1677([Symbol name=cudnn_sdpa_fwd])" + "1693([Symbol name=mul])" + "1692([Symbol name=broadcast_in_dim])" -> "1693([Symbol name=mul])" + "1685([Symbol name=convert_element_type])" -> "1693([Symbol name=mul])" + "1686([Symbol name=mul])" + "1685([Symbol name=convert_element_type])" -> "1686([Symbol name=mul])" + "1706([Symbol name=reciprocal])" + "1705([Symbol name=add])" -> "1706([Symbol name=reciprocal])" + "1729([Symbol name=mul])" + "1728([Symbol name=broadcast_in_dim])" -> "1729([Symbol name=mul])" + "1721([Symbol name=convert_element_type])" -> "1729([Symbol name=mul])" + "1722([Symbol name=mul])" + "1721([Symbol name=convert_element_type])" -> "1722([Symbol name=mul])" + "126([Symbol name=add])" + "125([Symbol name=true_divide])" -> "126([Symbol name=add])" + "194([Symbol name=convert_element_type])" + "193([Symbol name=mul])" -> "194([Symbol name=convert_element_type])" + "187([Symbol name=sum])" + "186([Symbol name=mul])" -> "187([Symbol name=sum])" + "144([Symbol name=slice_prim])" + "143([Symbol name=slice_prim])" -> "144([Symbol name=slice_prim])" + "145([Symbol name=slice_prim])" + "143([Symbol name=slice_prim])" -> "145([Symbol name=slice_prim])" + "151([Symbol name=convert_element_type])" + "143([Symbol name=slice_prim])" -> "151([Symbol name=convert_element_type])" + "160([Symbol name=slice_prim])" + "158([Symbol name=slice_prim])" -> "160([Symbol name=slice_prim])" + "166([Symbol name=convert_element_type])" + "158([Symbol name=slice_prim])" -> "166([Symbol name=convert_element_type])" + "159([Symbol name=slice_prim])" + "158([Symbol name=slice_prim])" -> "159([Symbol name=slice_prim])" + "178([Symbol name=transpose])" + "177([Symbol name=cudnn_sdpa_fwd])" -> "178([Symbol name=transpose])" + "207([Symbol name=convert_element_type])" + "206([Symbol name=reciprocal])" -> "207([Symbol name=convert_element_type])" + "230([Symbol name=convert_element_type])" + "229([Symbol name=mul])" -> "230([Symbol name=convert_element_type])" + "223([Symbol name=sum])" + "222([Symbol name=mul])" -> "223([Symbol name=sum])" + "251([Symbol name=convert_element_type])" + "243([Symbol name=slice_prim])" -> "251([Symbol name=convert_element_type])" + "244([Symbol name=slice_prim])" + "243([Symbol name=slice_prim])" -> "244([Symbol name=slice_prim])" + "245([Symbol name=slice_prim])" + "243([Symbol name=slice_prim])" -> "245([Symbol name=slice_prim])" + "266([Symbol name=convert_element_type])" + "258([Symbol name=slice_prim])" -> "266([Symbol name=convert_element_type])" + "259([Symbol name=slice_prim])" + "258([Symbol name=slice_prim])" -> "259([Symbol name=slice_prim])" + "260([Symbol name=slice_prim])" + "258([Symbol name=slice_prim])" -> "260([Symbol name=slice_prim])" + "278([Symbol name=transpose])" + "277([Symbol name=cudnn_sdpa_fwd])" -> "278([Symbol name=transpose])" + "294([Symbol name=convert_element_type])" + "293([Symbol name=mul])" -> "294([Symbol name=convert_element_type])" + "287([Symbol name=sum])" + "286([Symbol name=mul])" -> "287([Symbol name=sum])" + "307([Symbol name=convert_element_type])" + "306([Symbol name=reciprocal])" -> "307([Symbol name=convert_element_type])" + "330([Symbol name=convert_element_type])" + "329([Symbol name=mul])" -> "330([Symbol name=convert_element_type])" + "323([Symbol name=sum])" + "322([Symbol name=mul])" -> "323([Symbol name=sum])" + "344([Symbol name=slice_prim])" + "343([Symbol name=slice_prim])" -> "344([Symbol name=slice_prim])" + "345([Symbol name=slice_prim])" + "343([Symbol name=slice_prim])" -> "345([Symbol name=slice_prim])" + "351([Symbol name=convert_element_type])" + "343([Symbol name=slice_prim])" -> "351([Symbol name=convert_element_type])" + "360([Symbol name=slice_prim])" + "358([Symbol name=slice_prim])" -> "360([Symbol name=slice_prim])" + "366([Symbol name=convert_element_type])" + "358([Symbol name=slice_prim])" -> "366([Symbol name=convert_element_type])" + "359([Symbol name=slice_prim])" + "358([Symbol name=slice_prim])" -> "359([Symbol name=slice_prim])" + "378([Symbol name=transpose])" + "377([Symbol name=cudnn_sdpa_fwd])" -> "378([Symbol name=transpose])" + "394([Symbol name=convert_element_type])" + "393([Symbol name=mul])" -> "394([Symbol name=convert_element_type])" + "387([Symbol name=sum])" + "386([Symbol name=mul])" -> "387([Symbol name=sum])" + "407([Symbol name=convert_element_type])" + "406([Symbol name=reciprocal])" -> "407([Symbol name=convert_element_type])" + "430([Symbol name=convert_element_type])" + "429([Symbol name=mul])" -> "430([Symbol name=convert_element_type])" + "423([Symbol name=sum])" + "422([Symbol name=mul])" -> "423([Symbol name=sum])" + "451([Symbol name=convert_element_type])" + "443([Symbol name=slice_prim])" -> "451([Symbol name=convert_element_type])" + "444([Symbol name=slice_prim])" + "443([Symbol name=slice_prim])" -> "444([Symbol name=slice_prim])" + "445([Symbol name=slice_prim])" + "443([Symbol name=slice_prim])" -> "445([Symbol name=slice_prim])" + "466([Symbol name=convert_element_type])" + "458([Symbol name=slice_prim])" -> "466([Symbol name=convert_element_type])" + "459([Symbol name=slice_prim])" + "458([Symbol name=slice_prim])" -> "459([Symbol name=slice_prim])" + "460([Symbol name=slice_prim])" + "458([Symbol name=slice_prim])" -> "460([Symbol name=slice_prim])" + "478([Symbol name=transpose])" + "477([Symbol name=cudnn_sdpa_fwd])" -> "478([Symbol name=transpose])" + "494([Symbol name=convert_element_type])" + "493([Symbol name=mul])" -> "494([Symbol name=convert_element_type])" + "487([Symbol name=sum])" + "486([Symbol name=mul])" -> "487([Symbol name=sum])" + "507([Symbol name=convert_element_type])" + "506([Symbol name=reciprocal])" -> "507([Symbol name=convert_element_type])" + "530([Symbol name=convert_element_type])" + "529([Symbol name=mul])" -> "530([Symbol name=convert_element_type])" + "523([Symbol name=sum])" + "522([Symbol name=mul])" -> "523([Symbol name=sum])" + "544([Symbol name=slice_prim])" + "543([Symbol name=slice_prim])" -> "544([Symbol name=slice_prim])" + "545([Symbol name=slice_prim])" + "543([Symbol name=slice_prim])" -> "545([Symbol name=slice_prim])" + "551([Symbol name=convert_element_type])" + "543([Symbol name=slice_prim])" -> "551([Symbol name=convert_element_type])" + "560([Symbol name=slice_prim])" + "558([Symbol name=slice_prim])" -> "560([Symbol name=slice_prim])" + "566([Symbol name=convert_element_type])" + "558([Symbol name=slice_prim])" -> "566([Symbol name=convert_element_type])" + "559([Symbol name=slice_prim])" + "558([Symbol name=slice_prim])" -> "559([Symbol name=slice_prim])" + "578([Symbol name=transpose])" + "577([Symbol name=cudnn_sdpa_fwd])" -> "578([Symbol name=transpose])" + "594([Symbol name=convert_element_type])" + "593([Symbol name=mul])" -> "594([Symbol name=convert_element_type])" + "587([Symbol name=sum])" + "586([Symbol name=mul])" -> "587([Symbol name=sum])" + "607([Symbol name=convert_element_type])" + "606([Symbol name=reciprocal])" -> "607([Symbol name=convert_element_type])" + "630([Symbol name=convert_element_type])" + "629([Symbol name=mul])" -> "630([Symbol name=convert_element_type])" + "623([Symbol name=sum])" + "622([Symbol name=mul])" -> "623([Symbol name=sum])" + "651([Symbol name=convert_element_type])" + "643([Symbol name=slice_prim])" -> "651([Symbol name=convert_element_type])" + "644([Symbol name=slice_prim])" + "643([Symbol name=slice_prim])" -> "644([Symbol name=slice_prim])" + "645([Symbol name=slice_prim])" + "643([Symbol name=slice_prim])" -> "645([Symbol name=slice_prim])" + "666([Symbol name=convert_element_type])" + "658([Symbol name=slice_prim])" -> "666([Symbol name=convert_element_type])" + "659([Symbol name=slice_prim])" + "658([Symbol name=slice_prim])" -> "659([Symbol name=slice_prim])" + "660([Symbol name=slice_prim])" + "658([Symbol name=slice_prim])" -> "660([Symbol name=slice_prim])" + "678([Symbol name=transpose])" + "677([Symbol name=cudnn_sdpa_fwd])" -> "678([Symbol name=transpose])" + "694([Symbol name=convert_element_type])" + "693([Symbol name=mul])" -> "694([Symbol name=convert_element_type])" + "687([Symbol name=sum])" + "686([Symbol name=mul])" -> "687([Symbol name=sum])" + "707([Symbol name=convert_element_type])" + "706([Symbol name=reciprocal])" -> "707([Symbol name=convert_element_type])" + "730([Symbol name=convert_element_type])" + "729([Symbol name=mul])" -> "730([Symbol name=convert_element_type])" + "723([Symbol name=sum])" + "722([Symbol name=mul])" -> "723([Symbol name=sum])" + "744([Symbol name=slice_prim])" + "743([Symbol name=slice_prim])" -> "744([Symbol name=slice_prim])" + "745([Symbol name=slice_prim])" + "743([Symbol name=slice_prim])" -> "745([Symbol name=slice_prim])" + "751([Symbol name=convert_element_type])" + "743([Symbol name=slice_prim])" -> "751([Symbol name=convert_element_type])" + "760([Symbol name=slice_prim])" + "758([Symbol name=slice_prim])" -> "760([Symbol name=slice_prim])" + "766([Symbol name=convert_element_type])" + "758([Symbol name=slice_prim])" -> "766([Symbol name=convert_element_type])" + "759([Symbol name=slice_prim])" + "758([Symbol name=slice_prim])" -> "759([Symbol name=slice_prim])" + "778([Symbol name=transpose])" + "777([Symbol name=cudnn_sdpa_fwd])" -> "778([Symbol name=transpose])" + "794([Symbol name=convert_element_type])" + "793([Symbol name=mul])" -> "794([Symbol name=convert_element_type])" + "787([Symbol name=sum])" + "786([Symbol name=mul])" -> "787([Symbol name=sum])" + "807([Symbol name=convert_element_type])" + "806([Symbol name=reciprocal])" -> "807([Symbol name=convert_element_type])" + "830([Symbol name=convert_element_type])" + "829([Symbol name=mul])" -> "830([Symbol name=convert_element_type])" + "823([Symbol name=sum])" + "822([Symbol name=mul])" -> "823([Symbol name=sum])" + "851([Symbol name=convert_element_type])" + "843([Symbol name=slice_prim])" -> "851([Symbol name=convert_element_type])" + "844([Symbol name=slice_prim])" + "843([Symbol name=slice_prim])" -> "844([Symbol name=slice_prim])" + "845([Symbol name=slice_prim])" + "843([Symbol name=slice_prim])" -> "845([Symbol name=slice_prim])" + "866([Symbol name=convert_element_type])" + "858([Symbol name=slice_prim])" -> "866([Symbol name=convert_element_type])" + "859([Symbol name=slice_prim])" + "858([Symbol name=slice_prim])" -> "859([Symbol name=slice_prim])" + "860([Symbol name=slice_prim])" + "858([Symbol name=slice_prim])" -> "860([Symbol name=slice_prim])" + "878([Symbol name=transpose])" + "877([Symbol name=cudnn_sdpa_fwd])" -> "878([Symbol name=transpose])" + "894([Symbol name=convert_element_type])" + "893([Symbol name=mul])" -> "894([Symbol name=convert_element_type])" + "887([Symbol name=sum])" + "886([Symbol name=mul])" -> "887([Symbol name=sum])" + "907([Symbol name=convert_element_type])" + "906([Symbol name=reciprocal])" -> "907([Symbol name=convert_element_type])" + "930([Symbol name=convert_element_type])" + "929([Symbol name=mul])" -> "930([Symbol name=convert_element_type])" + "923([Symbol name=sum])" + "922([Symbol name=mul])" -> "923([Symbol name=sum])" + "944([Symbol name=slice_prim])" + "943([Symbol name=slice_prim])" -> "944([Symbol name=slice_prim])" + "945([Symbol name=slice_prim])" + "943([Symbol name=slice_prim])" -> "945([Symbol name=slice_prim])" + "951([Symbol name=convert_element_type])" + "943([Symbol name=slice_prim])" -> "951([Symbol name=convert_element_type])" + "960([Symbol name=slice_prim])" + "958([Symbol name=slice_prim])" -> "960([Symbol name=slice_prim])" + "966([Symbol name=convert_element_type])" + "958([Symbol name=slice_prim])" -> "966([Symbol name=convert_element_type])" + "959([Symbol name=slice_prim])" + "958([Symbol name=slice_prim])" -> "959([Symbol name=slice_prim])" + "978([Symbol name=transpose])" + "977([Symbol name=cudnn_sdpa_fwd])" -> "978([Symbol name=transpose])" + "994([Symbol name=convert_element_type])" + "993([Symbol name=mul])" -> "994([Symbol name=convert_element_type])" + "987([Symbol name=sum])" + "986([Symbol name=mul])" -> "987([Symbol name=sum])" + "1007([Symbol name=convert_element_type])" + "1006([Symbol name=reciprocal])" -> "1007([Symbol name=convert_element_type])" + "1030([Symbol name=convert_element_type])" + "1029([Symbol name=mul])" -> "1030([Symbol name=convert_element_type])" + "1023([Symbol name=sum])" + "1022([Symbol name=mul])" -> "1023([Symbol name=sum])" + "1051([Symbol name=convert_element_type])" + "1043([Symbol name=slice_prim])" -> "1051([Symbol name=convert_element_type])" + "1044([Symbol name=slice_prim])" + "1043([Symbol name=slice_prim])" -> "1044([Symbol name=slice_prim])" + "1045([Symbol name=slice_prim])" + "1043([Symbol name=slice_prim])" -> "1045([Symbol name=slice_prim])" + "1066([Symbol name=convert_element_type])" + "1058([Symbol name=slice_prim])" -> "1066([Symbol name=convert_element_type])" + "1059([Symbol name=slice_prim])" + "1058([Symbol name=slice_prim])" -> "1059([Symbol name=slice_prim])" + "1060([Symbol name=slice_prim])" + "1058([Symbol name=slice_prim])" -> "1060([Symbol name=slice_prim])" + "1078([Symbol name=transpose])" + "1077([Symbol name=cudnn_sdpa_fwd])" -> "1078([Symbol name=transpose])" + "1094([Symbol name=convert_element_type])" + "1093([Symbol name=mul])" -> "1094([Symbol name=convert_element_type])" + "1087([Symbol name=sum])" + "1086([Symbol name=mul])" -> "1087([Symbol name=sum])" + "1107([Symbol name=convert_element_type])" + "1106([Symbol name=reciprocal])" -> "1107([Symbol name=convert_element_type])" + "1130([Symbol name=convert_element_type])" + "1129([Symbol name=mul])" -> "1130([Symbol name=convert_element_type])" + "1123([Symbol name=sum])" + "1122([Symbol name=mul])" -> "1123([Symbol name=sum])" + "1144([Symbol name=slice_prim])" + "1143([Symbol name=slice_prim])" -> "1144([Symbol name=slice_prim])" + "1145([Symbol name=slice_prim])" + "1143([Symbol name=slice_prim])" -> "1145([Symbol name=slice_prim])" + "1151([Symbol name=convert_element_type])" + "1143([Symbol name=slice_prim])" -> "1151([Symbol name=convert_element_type])" + "1160([Symbol name=slice_prim])" + "1158([Symbol name=slice_prim])" -> "1160([Symbol name=slice_prim])" + "1166([Symbol name=convert_element_type])" + "1158([Symbol name=slice_prim])" -> "1166([Symbol name=convert_element_type])" + "1159([Symbol name=slice_prim])" + "1158([Symbol name=slice_prim])" -> "1159([Symbol name=slice_prim])" + "1178([Symbol name=transpose])" + "1177([Symbol name=cudnn_sdpa_fwd])" -> "1178([Symbol name=transpose])" + "1194([Symbol name=convert_element_type])" + "1193([Symbol name=mul])" -> "1194([Symbol name=convert_element_type])" + "1187([Symbol name=sum])" + "1186([Symbol name=mul])" -> "1187([Symbol name=sum])" + "1207([Symbol name=convert_element_type])" + "1206([Symbol name=reciprocal])" -> "1207([Symbol name=convert_element_type])" + "1230([Symbol name=convert_element_type])" + "1229([Symbol name=mul])" -> "1230([Symbol name=convert_element_type])" + "1223([Symbol name=sum])" + "1222([Symbol name=mul])" -> "1223([Symbol name=sum])" + "1251([Symbol name=convert_element_type])" + "1243([Symbol name=slice_prim])" -> "1251([Symbol name=convert_element_type])" + "1244([Symbol name=slice_prim])" + "1243([Symbol name=slice_prim])" -> "1244([Symbol name=slice_prim])" + "1245([Symbol name=slice_prim])" + "1243([Symbol name=slice_prim])" -> "1245([Symbol name=slice_prim])" + "1266([Symbol name=convert_element_type])" + "1258([Symbol name=slice_prim])" -> "1266([Symbol name=convert_element_type])" + "1259([Symbol name=slice_prim])" + "1258([Symbol name=slice_prim])" -> "1259([Symbol name=slice_prim])" + "1260([Symbol name=slice_prim])" + "1258([Symbol name=slice_prim])" -> "1260([Symbol name=slice_prim])" + "1278([Symbol name=transpose])" + "1277([Symbol name=cudnn_sdpa_fwd])" -> "1278([Symbol name=transpose])" + "1294([Symbol name=convert_element_type])" + "1293([Symbol name=mul])" -> "1294([Symbol name=convert_element_type])" + "1287([Symbol name=sum])" + "1286([Symbol name=mul])" -> "1287([Symbol name=sum])" + "1307([Symbol name=convert_element_type])" + "1306([Symbol name=reciprocal])" -> "1307([Symbol name=convert_element_type])" + "1330([Symbol name=convert_element_type])" + "1329([Symbol name=mul])" -> "1330([Symbol name=convert_element_type])" + "1323([Symbol name=sum])" + "1322([Symbol name=mul])" -> "1323([Symbol name=sum])" + "1344([Symbol name=slice_prim])" + "1343([Symbol name=slice_prim])" -> "1344([Symbol name=slice_prim])" + "1345([Symbol name=slice_prim])" + "1343([Symbol name=slice_prim])" -> "1345([Symbol name=slice_prim])" + "1351([Symbol name=convert_element_type])" + "1343([Symbol name=slice_prim])" -> "1351([Symbol name=convert_element_type])" + "1360([Symbol name=slice_prim])" + "1358([Symbol name=slice_prim])" -> "1360([Symbol name=slice_prim])" + "1366([Symbol name=convert_element_type])" + "1358([Symbol name=slice_prim])" -> "1366([Symbol name=convert_element_type])" + "1359([Symbol name=slice_prim])" + "1358([Symbol name=slice_prim])" -> "1359([Symbol name=slice_prim])" + "1378([Symbol name=transpose])" + "1377([Symbol name=cudnn_sdpa_fwd])" -> "1378([Symbol name=transpose])" + "1394([Symbol name=convert_element_type])" + "1393([Symbol name=mul])" -> "1394([Symbol name=convert_element_type])" + "1387([Symbol name=sum])" + "1386([Symbol name=mul])" -> "1387([Symbol name=sum])" + "1407([Symbol name=convert_element_type])" + "1406([Symbol name=reciprocal])" -> "1407([Symbol name=convert_element_type])" + "1430([Symbol name=convert_element_type])" + "1429([Symbol name=mul])" -> "1430([Symbol name=convert_element_type])" + "1423([Symbol name=sum])" + "1422([Symbol name=mul])" -> "1423([Symbol name=sum])" + "1451([Symbol name=convert_element_type])" + "1443([Symbol name=slice_prim])" -> "1451([Symbol name=convert_element_type])" + "1444([Symbol name=slice_prim])" + "1443([Symbol name=slice_prim])" -> "1444([Symbol name=slice_prim])" + "1445([Symbol name=slice_prim])" + "1443([Symbol name=slice_prim])" -> "1445([Symbol name=slice_prim])" + "1466([Symbol name=convert_element_type])" + "1458([Symbol name=slice_prim])" -> "1466([Symbol name=convert_element_type])" + "1459([Symbol name=slice_prim])" + "1458([Symbol name=slice_prim])" -> "1459([Symbol name=slice_prim])" + "1460([Symbol name=slice_prim])" + "1458([Symbol name=slice_prim])" -> "1460([Symbol name=slice_prim])" + "1478([Symbol name=transpose])" + "1477([Symbol name=cudnn_sdpa_fwd])" -> "1478([Symbol name=transpose])" + "1494([Symbol name=convert_element_type])" + "1493([Symbol name=mul])" -> "1494([Symbol name=convert_element_type])" + "1487([Symbol name=sum])" + "1486([Symbol name=mul])" -> "1487([Symbol name=sum])" + "1507([Symbol name=convert_element_type])" + "1506([Symbol name=reciprocal])" -> "1507([Symbol name=convert_element_type])" + "1530([Symbol name=convert_element_type])" + "1529([Symbol name=mul])" -> "1530([Symbol name=convert_element_type])" + "1523([Symbol name=sum])" + "1522([Symbol name=mul])" -> "1523([Symbol name=sum])" + "1544([Symbol name=slice_prim])" + "1543([Symbol name=slice_prim])" -> "1544([Symbol name=slice_prim])" + "1545([Symbol name=slice_prim])" + "1543([Symbol name=slice_prim])" -> "1545([Symbol name=slice_prim])" + "1551([Symbol name=convert_element_type])" + "1543([Symbol name=slice_prim])" -> "1551([Symbol name=convert_element_type])" + "1560([Symbol name=slice_prim])" + "1558([Symbol name=slice_prim])" -> "1560([Symbol name=slice_prim])" + "1566([Symbol name=convert_element_type])" + "1558([Symbol name=slice_prim])" -> "1566([Symbol name=convert_element_type])" + "1559([Symbol name=slice_prim])" + "1558([Symbol name=slice_prim])" -> "1559([Symbol name=slice_prim])" + "1578([Symbol name=transpose])" + "1577([Symbol name=cudnn_sdpa_fwd])" -> "1578([Symbol name=transpose])" + "1594([Symbol name=convert_element_type])" + "1593([Symbol name=mul])" -> "1594([Symbol name=convert_element_type])" + "1587([Symbol name=sum])" + "1586([Symbol name=mul])" -> "1587([Symbol name=sum])" + "1607([Symbol name=convert_element_type])" + "1606([Symbol name=reciprocal])" -> "1607([Symbol name=convert_element_type])" + "1630([Symbol name=convert_element_type])" + "1629([Symbol name=mul])" -> "1630([Symbol name=convert_element_type])" + "1623([Symbol name=sum])" + "1622([Symbol name=mul])" -> "1623([Symbol name=sum])" + "1651([Symbol name=convert_element_type])" + "1643([Symbol name=slice_prim])" -> "1651([Symbol name=convert_element_type])" + "1644([Symbol name=slice_prim])" + "1643([Symbol name=slice_prim])" -> "1644([Symbol name=slice_prim])" + "1645([Symbol name=slice_prim])" + "1643([Symbol name=slice_prim])" -> "1645([Symbol name=slice_prim])" + "1666([Symbol name=convert_element_type])" + "1658([Symbol name=slice_prim])" -> "1666([Symbol name=convert_element_type])" + "1659([Symbol name=slice_prim])" + "1658([Symbol name=slice_prim])" -> "1659([Symbol name=slice_prim])" + "1660([Symbol name=slice_prim])" + "1658([Symbol name=slice_prim])" -> "1660([Symbol name=slice_prim])" + "1678([Symbol name=transpose])" + "1677([Symbol name=cudnn_sdpa_fwd])" -> "1678([Symbol name=transpose])" + "1694([Symbol name=convert_element_type])" + "1693([Symbol name=mul])" -> "1694([Symbol name=convert_element_type])" + "1687([Symbol name=sum])" + "1686([Symbol name=mul])" -> "1687([Symbol name=sum])" + "1707([Symbol name=convert_element_type])" + "1706([Symbol name=reciprocal])" -> "1707([Symbol name=convert_element_type])" + "1730([Symbol name=convert_element_type])" + "1729([Symbol name=mul])" -> "1730([Symbol name=convert_element_type])" + "1723([Symbol name=sum])" + "1722([Symbol name=mul])" -> "1723([Symbol name=sum])" + "127([Symbol name=rsqrt])" + "126([Symbol name=add])" -> "127([Symbol name=rsqrt])" + "196([Symbol name=convert_element_type])" + "194([Symbol name=convert_element_type])" -> "196([Symbol name=convert_element_type])" + "188([Symbol name=broadcast_in_dim])" + "187([Symbol name=sum])" -> "188([Symbol name=broadcast_in_dim])" + "149([Symbol name=cat])" + "144([Symbol name=slice_prim])" -> "149([Symbol name=cat])" + "148([Symbol name=convert_element_type])" -> "149([Symbol name=cat])" + "146([Symbol name=convert_element_type])" + "145([Symbol name=slice_prim])" -> "146([Symbol name=convert_element_type])" + "161([Symbol name=convert_element_type])" + "160([Symbol name=slice_prim])" -> "161([Symbol name=convert_element_type])" + "164([Symbol name=cat])" + "163([Symbol name=convert_element_type])" -> "164([Symbol name=cat])" + "159([Symbol name=slice_prim])" -> "164([Symbol name=cat])" + "179([Symbol name=reshape])" + "178([Symbol name=transpose])" -> "179([Symbol name=reshape])" + "209([Symbol name=convert_element_type])" + "207([Symbol name=convert_element_type])" -> "209([Symbol name=convert_element_type])" + "232([Symbol name=convert_element_type])" + "230([Symbol name=convert_element_type])" -> "232([Symbol name=convert_element_type])" + "224([Symbol name=broadcast_in_dim])" + "223([Symbol name=sum])" -> "224([Symbol name=broadcast_in_dim])" + "249([Symbol name=cat])" + "248([Symbol name=convert_element_type])" -> "249([Symbol name=cat])" + "244([Symbol name=slice_prim])" -> "249([Symbol name=cat])" + "246([Symbol name=convert_element_type])" + "245([Symbol name=slice_prim])" -> "246([Symbol name=convert_element_type])" + "264([Symbol name=cat])" + "259([Symbol name=slice_prim])" -> "264([Symbol name=cat])" + "263([Symbol name=convert_element_type])" -> "264([Symbol name=cat])" + "261([Symbol name=convert_element_type])" + "260([Symbol name=slice_prim])" -> "261([Symbol name=convert_element_type])" + "279([Symbol name=reshape])" + "278([Symbol name=transpose])" -> "279([Symbol name=reshape])" + "296([Symbol name=convert_element_type])" + "294([Symbol name=convert_element_type])" -> "296([Symbol name=convert_element_type])" + "288([Symbol name=broadcast_in_dim])" + "287([Symbol name=sum])" -> "288([Symbol name=broadcast_in_dim])" + "309([Symbol name=convert_element_type])" + "307([Symbol name=convert_element_type])" -> "309([Symbol name=convert_element_type])" + "332([Symbol name=convert_element_type])" + "330([Symbol name=convert_element_type])" -> "332([Symbol name=convert_element_type])" + "324([Symbol name=broadcast_in_dim])" + "323([Symbol name=sum])" -> "324([Symbol name=broadcast_in_dim])" + "349([Symbol name=cat])" + "344([Symbol name=slice_prim])" -> "349([Symbol name=cat])" + "348([Symbol name=convert_element_type])" -> "349([Symbol name=cat])" + "346([Symbol name=convert_element_type])" + "345([Symbol name=slice_prim])" -> "346([Symbol name=convert_element_type])" + "361([Symbol name=convert_element_type])" + "360([Symbol name=slice_prim])" -> "361([Symbol name=convert_element_type])" + "364([Symbol name=cat])" + "363([Symbol name=convert_element_type])" -> "364([Symbol name=cat])" + "359([Symbol name=slice_prim])" -> "364([Symbol name=cat])" + "379([Symbol name=reshape])" + "378([Symbol name=transpose])" -> "379([Symbol name=reshape])" + "396([Symbol name=convert_element_type])" + "394([Symbol name=convert_element_type])" -> "396([Symbol name=convert_element_type])" + "388([Symbol name=broadcast_in_dim])" + "387([Symbol name=sum])" -> "388([Symbol name=broadcast_in_dim])" + "409([Symbol name=convert_element_type])" + "407([Symbol name=convert_element_type])" -> "409([Symbol name=convert_element_type])" + "432([Symbol name=convert_element_type])" + "430([Symbol name=convert_element_type])" -> "432([Symbol name=convert_element_type])" + "424([Symbol name=broadcast_in_dim])" + "423([Symbol name=sum])" -> "424([Symbol name=broadcast_in_dim])" + "449([Symbol name=cat])" + "448([Symbol name=convert_element_type])" -> "449([Symbol name=cat])" + "444([Symbol name=slice_prim])" -> "449([Symbol name=cat])" + "446([Symbol name=convert_element_type])" + "445([Symbol name=slice_prim])" -> "446([Symbol name=convert_element_type])" + "464([Symbol name=cat])" + "459([Symbol name=slice_prim])" -> "464([Symbol name=cat])" + "463([Symbol name=convert_element_type])" -> "464([Symbol name=cat])" + "461([Symbol name=convert_element_type])" + "460([Symbol name=slice_prim])" -> "461([Symbol name=convert_element_type])" + "479([Symbol name=reshape])" + "478([Symbol name=transpose])" -> "479([Symbol name=reshape])" + "496([Symbol name=convert_element_type])" + "494([Symbol name=convert_element_type])" -> "496([Symbol name=convert_element_type])" + "488([Symbol name=broadcast_in_dim])" + "487([Symbol name=sum])" -> "488([Symbol name=broadcast_in_dim])" + "509([Symbol name=convert_element_type])" + "507([Symbol name=convert_element_type])" -> "509([Symbol name=convert_element_type])" + "532([Symbol name=convert_element_type])" + "530([Symbol name=convert_element_type])" -> "532([Symbol name=convert_element_type])" + "524([Symbol name=broadcast_in_dim])" + "523([Symbol name=sum])" -> "524([Symbol name=broadcast_in_dim])" + "549([Symbol name=cat])" + "544([Symbol name=slice_prim])" -> "549([Symbol name=cat])" + "548([Symbol name=convert_element_type])" -> "549([Symbol name=cat])" + "546([Symbol name=convert_element_type])" + "545([Symbol name=slice_prim])" -> "546([Symbol name=convert_element_type])" + "561([Symbol name=convert_element_type])" + "560([Symbol name=slice_prim])" -> "561([Symbol name=convert_element_type])" + "564([Symbol name=cat])" + "563([Symbol name=convert_element_type])" -> "564([Symbol name=cat])" + "559([Symbol name=slice_prim])" -> "564([Symbol name=cat])" + "579([Symbol name=reshape])" + "578([Symbol name=transpose])" -> "579([Symbol name=reshape])" + "596([Symbol name=convert_element_type])" + "594([Symbol name=convert_element_type])" -> "596([Symbol name=convert_element_type])" + "588([Symbol name=broadcast_in_dim])" + "587([Symbol name=sum])" -> "588([Symbol name=broadcast_in_dim])" + "609([Symbol name=convert_element_type])" + "607([Symbol name=convert_element_type])" -> "609([Symbol name=convert_element_type])" + "632([Symbol name=convert_element_type])" + "630([Symbol name=convert_element_type])" -> "632([Symbol name=convert_element_type])" + "624([Symbol name=broadcast_in_dim])" + "623([Symbol name=sum])" -> "624([Symbol name=broadcast_in_dim])" + "649([Symbol name=cat])" + "648([Symbol name=convert_element_type])" -> "649([Symbol name=cat])" + "644([Symbol name=slice_prim])" -> "649([Symbol name=cat])" + "646([Symbol name=convert_element_type])" + "645([Symbol name=slice_prim])" -> "646([Symbol name=convert_element_type])" + "664([Symbol name=cat])" + "659([Symbol name=slice_prim])" -> "664([Symbol name=cat])" + "663([Symbol name=convert_element_type])" -> "664([Symbol name=cat])" + "661([Symbol name=convert_element_type])" + "660([Symbol name=slice_prim])" -> "661([Symbol name=convert_element_type])" + "679([Symbol name=reshape])" + "678([Symbol name=transpose])" -> "679([Symbol name=reshape])" + "696([Symbol name=convert_element_type])" + "694([Symbol name=convert_element_type])" -> "696([Symbol name=convert_element_type])" + "688([Symbol name=broadcast_in_dim])" + "687([Symbol name=sum])" -> "688([Symbol name=broadcast_in_dim])" + "709([Symbol name=convert_element_type])" + "707([Symbol name=convert_element_type])" -> "709([Symbol name=convert_element_type])" + "732([Symbol name=convert_element_type])" + "730([Symbol name=convert_element_type])" -> "732([Symbol name=convert_element_type])" + "724([Symbol name=broadcast_in_dim])" + "723([Symbol name=sum])" -> "724([Symbol name=broadcast_in_dim])" + "749([Symbol name=cat])" + "744([Symbol name=slice_prim])" -> "749([Symbol name=cat])" + "748([Symbol name=convert_element_type])" -> "749([Symbol name=cat])" + "746([Symbol name=convert_element_type])" + "745([Symbol name=slice_prim])" -> "746([Symbol name=convert_element_type])" + "761([Symbol name=convert_element_type])" + "760([Symbol name=slice_prim])" -> "761([Symbol name=convert_element_type])" + "764([Symbol name=cat])" + "763([Symbol name=convert_element_type])" -> "764([Symbol name=cat])" + "759([Symbol name=slice_prim])" -> "764([Symbol name=cat])" + "779([Symbol name=reshape])" + "778([Symbol name=transpose])" -> "779([Symbol name=reshape])" + "796([Symbol name=convert_element_type])" + "794([Symbol name=convert_element_type])" -> "796([Symbol name=convert_element_type])" + "788([Symbol name=broadcast_in_dim])" + "787([Symbol name=sum])" -> "788([Symbol name=broadcast_in_dim])" + "809([Symbol name=convert_element_type])" + "807([Symbol name=convert_element_type])" -> "809([Symbol name=convert_element_type])" + "832([Symbol name=convert_element_type])" + "830([Symbol name=convert_element_type])" -> "832([Symbol name=convert_element_type])" + "824([Symbol name=broadcast_in_dim])" + "823([Symbol name=sum])" -> "824([Symbol name=broadcast_in_dim])" + "849([Symbol name=cat])" + "848([Symbol name=convert_element_type])" -> "849([Symbol name=cat])" + "844([Symbol name=slice_prim])" -> "849([Symbol name=cat])" + "846([Symbol name=convert_element_type])" + "845([Symbol name=slice_prim])" -> "846([Symbol name=convert_element_type])" + "864([Symbol name=cat])" + "859([Symbol name=slice_prim])" -> "864([Symbol name=cat])" + "863([Symbol name=convert_element_type])" -> "864([Symbol name=cat])" + "861([Symbol name=convert_element_type])" + "860([Symbol name=slice_prim])" -> "861([Symbol name=convert_element_type])" + "879([Symbol name=reshape])" + "878([Symbol name=transpose])" -> "879([Symbol name=reshape])" + "896([Symbol name=convert_element_type])" + "894([Symbol name=convert_element_type])" -> "896([Symbol name=convert_element_type])" + "888([Symbol name=broadcast_in_dim])" + "887([Symbol name=sum])" -> "888([Symbol name=broadcast_in_dim])" + "909([Symbol name=convert_element_type])" + "907([Symbol name=convert_element_type])" -> "909([Symbol name=convert_element_type])" + "932([Symbol name=convert_element_type])" + "930([Symbol name=convert_element_type])" -> "932([Symbol name=convert_element_type])" + "924([Symbol name=broadcast_in_dim])" + "923([Symbol name=sum])" -> "924([Symbol name=broadcast_in_dim])" + "949([Symbol name=cat])" + "944([Symbol name=slice_prim])" -> "949([Symbol name=cat])" + "948([Symbol name=convert_element_type])" -> "949([Symbol name=cat])" + "946([Symbol name=convert_element_type])" + "945([Symbol name=slice_prim])" -> "946([Symbol name=convert_element_type])" + "961([Symbol name=convert_element_type])" + "960([Symbol name=slice_prim])" -> "961([Symbol name=convert_element_type])" + "964([Symbol name=cat])" + "963([Symbol name=convert_element_type])" -> "964([Symbol name=cat])" + "959([Symbol name=slice_prim])" -> "964([Symbol name=cat])" + "979([Symbol name=reshape])" + "978([Symbol name=transpose])" -> "979([Symbol name=reshape])" + "996([Symbol name=convert_element_type])" + "994([Symbol name=convert_element_type])" -> "996([Symbol name=convert_element_type])" + "988([Symbol name=broadcast_in_dim])" + "987([Symbol name=sum])" -> "988([Symbol name=broadcast_in_dim])" + "1009([Symbol name=convert_element_type])" + "1007([Symbol name=convert_element_type])" -> "1009([Symbol name=convert_element_type])" + "1032([Symbol name=convert_element_type])" + "1030([Symbol name=convert_element_type])" -> "1032([Symbol name=convert_element_type])" + "1024([Symbol name=broadcast_in_dim])" + "1023([Symbol name=sum])" -> "1024([Symbol name=broadcast_in_dim])" + "1049([Symbol name=cat])" + "1048([Symbol name=convert_element_type])" -> "1049([Symbol name=cat])" + "1044([Symbol name=slice_prim])" -> "1049([Symbol name=cat])" + "1046([Symbol name=convert_element_type])" + "1045([Symbol name=slice_prim])" -> "1046([Symbol name=convert_element_type])" + "1064([Symbol name=cat])" + "1059([Symbol name=slice_prim])" -> "1064([Symbol name=cat])" + "1063([Symbol name=convert_element_type])" -> "1064([Symbol name=cat])" + "1061([Symbol name=convert_element_type])" + "1060([Symbol name=slice_prim])" -> "1061([Symbol name=convert_element_type])" + "1079([Symbol name=reshape])" + "1078([Symbol name=transpose])" -> "1079([Symbol name=reshape])" + "1096([Symbol name=convert_element_type])" + "1094([Symbol name=convert_element_type])" -> "1096([Symbol name=convert_element_type])" + "1088([Symbol name=broadcast_in_dim])" + "1087([Symbol name=sum])" -> "1088([Symbol name=broadcast_in_dim])" + "1109([Symbol name=convert_element_type])" + "1107([Symbol name=convert_element_type])" -> "1109([Symbol name=convert_element_type])" + "1132([Symbol name=convert_element_type])" + "1130([Symbol name=convert_element_type])" -> "1132([Symbol name=convert_element_type])" + "1124([Symbol name=broadcast_in_dim])" + "1123([Symbol name=sum])" -> "1124([Symbol name=broadcast_in_dim])" + "1149([Symbol name=cat])" + "1144([Symbol name=slice_prim])" -> "1149([Symbol name=cat])" + "1148([Symbol name=convert_element_type])" -> "1149([Symbol name=cat])" + "1146([Symbol name=convert_element_type])" + "1145([Symbol name=slice_prim])" -> "1146([Symbol name=convert_element_type])" + "1161([Symbol name=convert_element_type])" + "1160([Symbol name=slice_prim])" -> "1161([Symbol name=convert_element_type])" + "1164([Symbol name=cat])" + "1163([Symbol name=convert_element_type])" -> "1164([Symbol name=cat])" + "1159([Symbol name=slice_prim])" -> "1164([Symbol name=cat])" + "1179([Symbol name=reshape])" + "1178([Symbol name=transpose])" -> "1179([Symbol name=reshape])" + "1196([Symbol name=convert_element_type])" + "1194([Symbol name=convert_element_type])" -> "1196([Symbol name=convert_element_type])" + "1188([Symbol name=broadcast_in_dim])" + "1187([Symbol name=sum])" -> "1188([Symbol name=broadcast_in_dim])" + "1209([Symbol name=convert_element_type])" + "1207([Symbol name=convert_element_type])" -> "1209([Symbol name=convert_element_type])" + "1232([Symbol name=convert_element_type])" + "1230([Symbol name=convert_element_type])" -> "1232([Symbol name=convert_element_type])" + "1224([Symbol name=broadcast_in_dim])" + "1223([Symbol name=sum])" -> "1224([Symbol name=broadcast_in_dim])" + "1249([Symbol name=cat])" + "1248([Symbol name=convert_element_type])" -> "1249([Symbol name=cat])" + "1244([Symbol name=slice_prim])" -> "1249([Symbol name=cat])" + "1246([Symbol name=convert_element_type])" + "1245([Symbol name=slice_prim])" -> "1246([Symbol name=convert_element_type])" + "1264([Symbol name=cat])" + "1259([Symbol name=slice_prim])" -> "1264([Symbol name=cat])" + "1263([Symbol name=convert_element_type])" -> "1264([Symbol name=cat])" + "1261([Symbol name=convert_element_type])" + "1260([Symbol name=slice_prim])" -> "1261([Symbol name=convert_element_type])" + "1279([Symbol name=reshape])" + "1278([Symbol name=transpose])" -> "1279([Symbol name=reshape])" + "1296([Symbol name=convert_element_type])" + "1294([Symbol name=convert_element_type])" -> "1296([Symbol name=convert_element_type])" + "1288([Symbol name=broadcast_in_dim])" + "1287([Symbol name=sum])" -> "1288([Symbol name=broadcast_in_dim])" + "1309([Symbol name=convert_element_type])" + "1307([Symbol name=convert_element_type])" -> "1309([Symbol name=convert_element_type])" + "1332([Symbol name=convert_element_type])" + "1330([Symbol name=convert_element_type])" -> "1332([Symbol name=convert_element_type])" + "1324([Symbol name=broadcast_in_dim])" + "1323([Symbol name=sum])" -> "1324([Symbol name=broadcast_in_dim])" + "1349([Symbol name=cat])" + "1344([Symbol name=slice_prim])" -> "1349([Symbol name=cat])" + "1348([Symbol name=convert_element_type])" -> "1349([Symbol name=cat])" + "1346([Symbol name=convert_element_type])" + "1345([Symbol name=slice_prim])" -> "1346([Symbol name=convert_element_type])" + "1361([Symbol name=convert_element_type])" + "1360([Symbol name=slice_prim])" -> "1361([Symbol name=convert_element_type])" + "1364([Symbol name=cat])" + "1363([Symbol name=convert_element_type])" -> "1364([Symbol name=cat])" + "1359([Symbol name=slice_prim])" -> "1364([Symbol name=cat])" + "1379([Symbol name=reshape])" + "1378([Symbol name=transpose])" -> "1379([Symbol name=reshape])" + "1396([Symbol name=convert_element_type])" + "1394([Symbol name=convert_element_type])" -> "1396([Symbol name=convert_element_type])" + "1388([Symbol name=broadcast_in_dim])" + "1387([Symbol name=sum])" -> "1388([Symbol name=broadcast_in_dim])" + "1409([Symbol name=convert_element_type])" + "1407([Symbol name=convert_element_type])" -> "1409([Symbol name=convert_element_type])" + "1432([Symbol name=convert_element_type])" + "1430([Symbol name=convert_element_type])" -> "1432([Symbol name=convert_element_type])" + "1424([Symbol name=broadcast_in_dim])" + "1423([Symbol name=sum])" -> "1424([Symbol name=broadcast_in_dim])" + "1449([Symbol name=cat])" + "1448([Symbol name=convert_element_type])" -> "1449([Symbol name=cat])" + "1444([Symbol name=slice_prim])" -> "1449([Symbol name=cat])" + "1446([Symbol name=convert_element_type])" + "1445([Symbol name=slice_prim])" -> "1446([Symbol name=convert_element_type])" + "1464([Symbol name=cat])" + "1459([Symbol name=slice_prim])" -> "1464([Symbol name=cat])" + "1463([Symbol name=convert_element_type])" -> "1464([Symbol name=cat])" + "1461([Symbol name=convert_element_type])" + "1460([Symbol name=slice_prim])" -> "1461([Symbol name=convert_element_type])" + "1479([Symbol name=reshape])" + "1478([Symbol name=transpose])" -> "1479([Symbol name=reshape])" + "1496([Symbol name=convert_element_type])" + "1494([Symbol name=convert_element_type])" -> "1496([Symbol name=convert_element_type])" + "1488([Symbol name=broadcast_in_dim])" + "1487([Symbol name=sum])" -> "1488([Symbol name=broadcast_in_dim])" + "1509([Symbol name=convert_element_type])" + "1507([Symbol name=convert_element_type])" -> "1509([Symbol name=convert_element_type])" + "1532([Symbol name=convert_element_type])" + "1530([Symbol name=convert_element_type])" -> "1532([Symbol name=convert_element_type])" + "1524([Symbol name=broadcast_in_dim])" + "1523([Symbol name=sum])" -> "1524([Symbol name=broadcast_in_dim])" + "1549([Symbol name=cat])" + "1544([Symbol name=slice_prim])" -> "1549([Symbol name=cat])" + "1548([Symbol name=convert_element_type])" -> "1549([Symbol name=cat])" + "1546([Symbol name=convert_element_type])" + "1545([Symbol name=slice_prim])" -> "1546([Symbol name=convert_element_type])" + "1561([Symbol name=convert_element_type])" + "1560([Symbol name=slice_prim])" -> "1561([Symbol name=convert_element_type])" + "1564([Symbol name=cat])" + "1563([Symbol name=convert_element_type])" -> "1564([Symbol name=cat])" + "1559([Symbol name=slice_prim])" -> "1564([Symbol name=cat])" + "1579([Symbol name=reshape])" + "1578([Symbol name=transpose])" -> "1579([Symbol name=reshape])" + "1596([Symbol name=convert_element_type])" + "1594([Symbol name=convert_element_type])" -> "1596([Symbol name=convert_element_type])" + "1588([Symbol name=broadcast_in_dim])" + "1587([Symbol name=sum])" -> "1588([Symbol name=broadcast_in_dim])" + "1609([Symbol name=convert_element_type])" + "1607([Symbol name=convert_element_type])" -> "1609([Symbol name=convert_element_type])" + "1632([Symbol name=convert_element_type])" + "1630([Symbol name=convert_element_type])" -> "1632([Symbol name=convert_element_type])" + "1624([Symbol name=broadcast_in_dim])" + "1623([Symbol name=sum])" -> "1624([Symbol name=broadcast_in_dim])" + "1649([Symbol name=cat])" + "1648([Symbol name=convert_element_type])" -> "1649([Symbol name=cat])" + "1644([Symbol name=slice_prim])" -> "1649([Symbol name=cat])" + "1646([Symbol name=convert_element_type])" + "1645([Symbol name=slice_prim])" -> "1646([Symbol name=convert_element_type])" + "1664([Symbol name=cat])" + "1659([Symbol name=slice_prim])" -> "1664([Symbol name=cat])" + "1663([Symbol name=convert_element_type])" -> "1664([Symbol name=cat])" + "1661([Symbol name=convert_element_type])" + "1660([Symbol name=slice_prim])" -> "1661([Symbol name=convert_element_type])" + "1679([Symbol name=reshape])" + "1678([Symbol name=transpose])" -> "1679([Symbol name=reshape])" + "1696([Symbol name=convert_element_type])" + "1694([Symbol name=convert_element_type])" -> "1696([Symbol name=convert_element_type])" + "1688([Symbol name=broadcast_in_dim])" + "1687([Symbol name=sum])" -> "1688([Symbol name=broadcast_in_dim])" + "1709([Symbol name=convert_element_type])" + "1707([Symbol name=convert_element_type])" -> "1709([Symbol name=convert_element_type])" + "1732([Symbol name=convert_element_type])" + "1730([Symbol name=convert_element_type])" -> "1732([Symbol name=convert_element_type])" + "1724([Symbol name=broadcast_in_dim])" + "1723([Symbol name=sum])" -> "1724([Symbol name=broadcast_in_dim])" + "128([Symbol name=broadcast_in_dim])" + "127([Symbol name=rsqrt])" -> "128([Symbol name=broadcast_in_dim])" + "189([Symbol name=true_divide])" + "188([Symbol name=broadcast_in_dim])" -> "189([Symbol name=true_divide])" + "154([Symbol name=convert_element_type])" + "149([Symbol name=cat])" -> "154([Symbol name=convert_element_type])" + "147([Symbol name=neg])" + "146([Symbol name=convert_element_type])" -> "147([Symbol name=neg])" + "162([Symbol name=neg])" + "161([Symbol name=convert_element_type])" -> "162([Symbol name=neg])" + "169([Symbol name=convert_element_type])" + "164([Symbol name=cat])" -> "169([Symbol name=convert_element_type])" + "225([Symbol name=true_divide])" + "224([Symbol name=broadcast_in_dim])" -> "225([Symbol name=true_divide])" + "254([Symbol name=convert_element_type])" + "249([Symbol name=cat])" -> "254([Symbol name=convert_element_type])" + "247([Symbol name=neg])" + "246([Symbol name=convert_element_type])" -> "247([Symbol name=neg])" + "269([Symbol name=convert_element_type])" + "264([Symbol name=cat])" -> "269([Symbol name=convert_element_type])" + "262([Symbol name=neg])" + "261([Symbol name=convert_element_type])" -> "262([Symbol name=neg])" + "289([Symbol name=true_divide])" + "288([Symbol name=broadcast_in_dim])" -> "289([Symbol name=true_divide])" + "325([Symbol name=true_divide])" + "324([Symbol name=broadcast_in_dim])" -> "325([Symbol name=true_divide])" + "354([Symbol name=convert_element_type])" + "349([Symbol name=cat])" -> "354([Symbol name=convert_element_type])" + "347([Symbol name=neg])" + "346([Symbol name=convert_element_type])" -> "347([Symbol name=neg])" + "362([Symbol name=neg])" + "361([Symbol name=convert_element_type])" -> "362([Symbol name=neg])" + "369([Symbol name=convert_element_type])" + "364([Symbol name=cat])" -> "369([Symbol name=convert_element_type])" + "389([Symbol name=true_divide])" + "388([Symbol name=broadcast_in_dim])" -> "389([Symbol name=true_divide])" + "425([Symbol name=true_divide])" + "424([Symbol name=broadcast_in_dim])" -> "425([Symbol name=true_divide])" + "454([Symbol name=convert_element_type])" + "449([Symbol name=cat])" -> "454([Symbol name=convert_element_type])" + "447([Symbol name=neg])" + "446([Symbol name=convert_element_type])" -> "447([Symbol name=neg])" + "469([Symbol name=convert_element_type])" + "464([Symbol name=cat])" -> "469([Symbol name=convert_element_type])" + "462([Symbol name=neg])" + "461([Symbol name=convert_element_type])" -> "462([Symbol name=neg])" + "489([Symbol name=true_divide])" + "488([Symbol name=broadcast_in_dim])" -> "489([Symbol name=true_divide])" + "525([Symbol name=true_divide])" + "524([Symbol name=broadcast_in_dim])" -> "525([Symbol name=true_divide])" + "554([Symbol name=convert_element_type])" + "549([Symbol name=cat])" -> "554([Symbol name=convert_element_type])" + "547([Symbol name=neg])" + "546([Symbol name=convert_element_type])" -> "547([Symbol name=neg])" + "562([Symbol name=neg])" + "561([Symbol name=convert_element_type])" -> "562([Symbol name=neg])" + "569([Symbol name=convert_element_type])" + "564([Symbol name=cat])" -> "569([Symbol name=convert_element_type])" + "589([Symbol name=true_divide])" + "588([Symbol name=broadcast_in_dim])" -> "589([Symbol name=true_divide])" + "625([Symbol name=true_divide])" + "624([Symbol name=broadcast_in_dim])" -> "625([Symbol name=true_divide])" + "654([Symbol name=convert_element_type])" + "649([Symbol name=cat])" -> "654([Symbol name=convert_element_type])" + "647([Symbol name=neg])" + "646([Symbol name=convert_element_type])" -> "647([Symbol name=neg])" + "669([Symbol name=convert_element_type])" + "664([Symbol name=cat])" -> "669([Symbol name=convert_element_type])" + "662([Symbol name=neg])" + "661([Symbol name=convert_element_type])" -> "662([Symbol name=neg])" + "689([Symbol name=true_divide])" + "688([Symbol name=broadcast_in_dim])" -> "689([Symbol name=true_divide])" + "725([Symbol name=true_divide])" + "724([Symbol name=broadcast_in_dim])" -> "725([Symbol name=true_divide])" + "754([Symbol name=convert_element_type])" + "749([Symbol name=cat])" -> "754([Symbol name=convert_element_type])" + "747([Symbol name=neg])" + "746([Symbol name=convert_element_type])" -> "747([Symbol name=neg])" + "762([Symbol name=neg])" + "761([Symbol name=convert_element_type])" -> "762([Symbol name=neg])" + "769([Symbol name=convert_element_type])" + "764([Symbol name=cat])" -> "769([Symbol name=convert_element_type])" + "789([Symbol name=true_divide])" + "788([Symbol name=broadcast_in_dim])" -> "789([Symbol name=true_divide])" + "825([Symbol name=true_divide])" + "824([Symbol name=broadcast_in_dim])" -> "825([Symbol name=true_divide])" + "854([Symbol name=convert_element_type])" + "849([Symbol name=cat])" -> "854([Symbol name=convert_element_type])" + "847([Symbol name=neg])" + "846([Symbol name=convert_element_type])" -> "847([Symbol name=neg])" + "869([Symbol name=convert_element_type])" + "864([Symbol name=cat])" -> "869([Symbol name=convert_element_type])" + "862([Symbol name=neg])" + "861([Symbol name=convert_element_type])" -> "862([Symbol name=neg])" + "889([Symbol name=true_divide])" + "888([Symbol name=broadcast_in_dim])" -> "889([Symbol name=true_divide])" + "925([Symbol name=true_divide])" + "924([Symbol name=broadcast_in_dim])" -> "925([Symbol name=true_divide])" + "954([Symbol name=convert_element_type])" + "949([Symbol name=cat])" -> "954([Symbol name=convert_element_type])" + "947([Symbol name=neg])" + "946([Symbol name=convert_element_type])" -> "947([Symbol name=neg])" + "962([Symbol name=neg])" + "961([Symbol name=convert_element_type])" -> "962([Symbol name=neg])" + "969([Symbol name=convert_element_type])" + "964([Symbol name=cat])" -> "969([Symbol name=convert_element_type])" + "989([Symbol name=true_divide])" + "988([Symbol name=broadcast_in_dim])" -> "989([Symbol name=true_divide])" + "1025([Symbol name=true_divide])" + "1024([Symbol name=broadcast_in_dim])" -> "1025([Symbol name=true_divide])" + "1054([Symbol name=convert_element_type])" + "1049([Symbol name=cat])" -> "1054([Symbol name=convert_element_type])" + "1047([Symbol name=neg])" + "1046([Symbol name=convert_element_type])" -> "1047([Symbol name=neg])" + "1069([Symbol name=convert_element_type])" + "1064([Symbol name=cat])" -> "1069([Symbol name=convert_element_type])" + "1062([Symbol name=neg])" + "1061([Symbol name=convert_element_type])" -> "1062([Symbol name=neg])" + "1089([Symbol name=true_divide])" + "1088([Symbol name=broadcast_in_dim])" -> "1089([Symbol name=true_divide])" + "1125([Symbol name=true_divide])" + "1124([Symbol name=broadcast_in_dim])" -> "1125([Symbol name=true_divide])" + "1154([Symbol name=convert_element_type])" + "1149([Symbol name=cat])" -> "1154([Symbol name=convert_element_type])" + "1147([Symbol name=neg])" + "1146([Symbol name=convert_element_type])" -> "1147([Symbol name=neg])" + "1162([Symbol name=neg])" + "1161([Symbol name=convert_element_type])" -> "1162([Symbol name=neg])" + "1169([Symbol name=convert_element_type])" + "1164([Symbol name=cat])" -> "1169([Symbol name=convert_element_type])" + "1189([Symbol name=true_divide])" + "1188([Symbol name=broadcast_in_dim])" -> "1189([Symbol name=true_divide])" + "1225([Symbol name=true_divide])" + "1224([Symbol name=broadcast_in_dim])" -> "1225([Symbol name=true_divide])" + "1254([Symbol name=convert_element_type])" + "1249([Symbol name=cat])" -> "1254([Symbol name=convert_element_type])" + "1247([Symbol name=neg])" + "1246([Symbol name=convert_element_type])" -> "1247([Symbol name=neg])" + "1269([Symbol name=convert_element_type])" + "1264([Symbol name=cat])" -> "1269([Symbol name=convert_element_type])" + "1262([Symbol name=neg])" + "1261([Symbol name=convert_element_type])" -> "1262([Symbol name=neg])" + "1289([Symbol name=true_divide])" + "1288([Symbol name=broadcast_in_dim])" -> "1289([Symbol name=true_divide])" + "1325([Symbol name=true_divide])" + "1324([Symbol name=broadcast_in_dim])" -> "1325([Symbol name=true_divide])" + "1354([Symbol name=convert_element_type])" + "1349([Symbol name=cat])" -> "1354([Symbol name=convert_element_type])" + "1347([Symbol name=neg])" + "1346([Symbol name=convert_element_type])" -> "1347([Symbol name=neg])" + "1362([Symbol name=neg])" + "1361([Symbol name=convert_element_type])" -> "1362([Symbol name=neg])" + "1369([Symbol name=convert_element_type])" + "1364([Symbol name=cat])" -> "1369([Symbol name=convert_element_type])" + "1389([Symbol name=true_divide])" + "1388([Symbol name=broadcast_in_dim])" -> "1389([Symbol name=true_divide])" + "1425([Symbol name=true_divide])" + "1424([Symbol name=broadcast_in_dim])" -> "1425([Symbol name=true_divide])" + "1454([Symbol name=convert_element_type])" + "1449([Symbol name=cat])" -> "1454([Symbol name=convert_element_type])" + "1447([Symbol name=neg])" + "1446([Symbol name=convert_element_type])" -> "1447([Symbol name=neg])" + "1469([Symbol name=convert_element_type])" + "1464([Symbol name=cat])" -> "1469([Symbol name=convert_element_type])" + "1462([Symbol name=neg])" + "1461([Symbol name=convert_element_type])" -> "1462([Symbol name=neg])" + "1489([Symbol name=true_divide])" + "1488([Symbol name=broadcast_in_dim])" -> "1489([Symbol name=true_divide])" + "1525([Symbol name=true_divide])" + "1524([Symbol name=broadcast_in_dim])" -> "1525([Symbol name=true_divide])" + "1554([Symbol name=convert_element_type])" + "1549([Symbol name=cat])" -> "1554([Symbol name=convert_element_type])" + "1547([Symbol name=neg])" + "1546([Symbol name=convert_element_type])" -> "1547([Symbol name=neg])" + "1562([Symbol name=neg])" + "1561([Symbol name=convert_element_type])" -> "1562([Symbol name=neg])" + "1569([Symbol name=convert_element_type])" + "1564([Symbol name=cat])" -> "1569([Symbol name=convert_element_type])" + "1589([Symbol name=true_divide])" + "1588([Symbol name=broadcast_in_dim])" -> "1589([Symbol name=true_divide])" + "1625([Symbol name=true_divide])" + "1624([Symbol name=broadcast_in_dim])" -> "1625([Symbol name=true_divide])" + "1654([Symbol name=convert_element_type])" + "1649([Symbol name=cat])" -> "1654([Symbol name=convert_element_type])" + "1647([Symbol name=neg])" + "1646([Symbol name=convert_element_type])" -> "1647([Symbol name=neg])" + "1669([Symbol name=convert_element_type])" + "1664([Symbol name=cat])" -> "1669([Symbol name=convert_element_type])" + "1662([Symbol name=neg])" + "1661([Symbol name=convert_element_type])" -> "1662([Symbol name=neg])" + "1689([Symbol name=true_divide])" + "1688([Symbol name=broadcast_in_dim])" -> "1689([Symbol name=true_divide])" + "1725([Symbol name=true_divide])" + "1724([Symbol name=broadcast_in_dim])" -> "1725([Symbol name=true_divide])" + "190([Symbol name=add])" + "189([Symbol name=true_divide])" -> "190([Symbol name=add])" + "148([Symbol name=convert_element_type])" + "147([Symbol name=neg])" -> "148([Symbol name=convert_element_type])" + "163([Symbol name=convert_element_type])" + "162([Symbol name=neg])" -> "163([Symbol name=convert_element_type])" + "226([Symbol name=add])" + "225([Symbol name=true_divide])" -> "226([Symbol name=add])" + "248([Symbol name=convert_element_type])" + "247([Symbol name=neg])" -> "248([Symbol name=convert_element_type])" + "263([Symbol name=convert_element_type])" + "262([Symbol name=neg])" -> "263([Symbol name=convert_element_type])" + "290([Symbol name=add])" + "289([Symbol name=true_divide])" -> "290([Symbol name=add])" + "326([Symbol name=add])" + "325([Symbol name=true_divide])" -> "326([Symbol name=add])" + "348([Symbol name=convert_element_type])" + "347([Symbol name=neg])" -> "348([Symbol name=convert_element_type])" + "363([Symbol name=convert_element_type])" + "362([Symbol name=neg])" -> "363([Symbol name=convert_element_type])" + "390([Symbol name=add])" + "389([Symbol name=true_divide])" -> "390([Symbol name=add])" + "426([Symbol name=add])" + "425([Symbol name=true_divide])" -> "426([Symbol name=add])" + "448([Symbol name=convert_element_type])" + "447([Symbol name=neg])" -> "448([Symbol name=convert_element_type])" + "463([Symbol name=convert_element_type])" + "462([Symbol name=neg])" -> "463([Symbol name=convert_element_type])" + "490([Symbol name=add])" + "489([Symbol name=true_divide])" -> "490([Symbol name=add])" + "526([Symbol name=add])" + "525([Symbol name=true_divide])" -> "526([Symbol name=add])" + "548([Symbol name=convert_element_type])" + "547([Symbol name=neg])" -> "548([Symbol name=convert_element_type])" + "563([Symbol name=convert_element_type])" + "562([Symbol name=neg])" -> "563([Symbol name=convert_element_type])" + "590([Symbol name=add])" + "589([Symbol name=true_divide])" -> "590([Symbol name=add])" + "626([Symbol name=add])" + "625([Symbol name=true_divide])" -> "626([Symbol name=add])" + "648([Symbol name=convert_element_type])" + "647([Symbol name=neg])" -> "648([Symbol name=convert_element_type])" + "663([Symbol name=convert_element_type])" + "662([Symbol name=neg])" -> "663([Symbol name=convert_element_type])" + "690([Symbol name=add])" + "689([Symbol name=true_divide])" -> "690([Symbol name=add])" + "726([Symbol name=add])" + "725([Symbol name=true_divide])" -> "726([Symbol name=add])" + "748([Symbol name=convert_element_type])" + "747([Symbol name=neg])" -> "748([Symbol name=convert_element_type])" + "763([Symbol name=convert_element_type])" + "762([Symbol name=neg])" -> "763([Symbol name=convert_element_type])" + "790([Symbol name=add])" + "789([Symbol name=true_divide])" -> "790([Symbol name=add])" + "826([Symbol name=add])" + "825([Symbol name=true_divide])" -> "826([Symbol name=add])" + "848([Symbol name=convert_element_type])" + "847([Symbol name=neg])" -> "848([Symbol name=convert_element_type])" + "863([Symbol name=convert_element_type])" + "862([Symbol name=neg])" -> "863([Symbol name=convert_element_type])" + "890([Symbol name=add])" + "889([Symbol name=true_divide])" -> "890([Symbol name=add])" + "926([Symbol name=add])" + "925([Symbol name=true_divide])" -> "926([Symbol name=add])" + "948([Symbol name=convert_element_type])" + "947([Symbol name=neg])" -> "948([Symbol name=convert_element_type])" + "963([Symbol name=convert_element_type])" + "962([Symbol name=neg])" -> "963([Symbol name=convert_element_type])" + "990([Symbol name=add])" + "989([Symbol name=true_divide])" -> "990([Symbol name=add])" + "1026([Symbol name=add])" + "1025([Symbol name=true_divide])" -> "1026([Symbol name=add])" + "1048([Symbol name=convert_element_type])" + "1047([Symbol name=neg])" -> "1048([Symbol name=convert_element_type])" + "1063([Symbol name=convert_element_type])" + "1062([Symbol name=neg])" -> "1063([Symbol name=convert_element_type])" + "1090([Symbol name=add])" + "1089([Symbol name=true_divide])" -> "1090([Symbol name=add])" + "1126([Symbol name=add])" + "1125([Symbol name=true_divide])" -> "1126([Symbol name=add])" + "1148([Symbol name=convert_element_type])" + "1147([Symbol name=neg])" -> "1148([Symbol name=convert_element_type])" + "1163([Symbol name=convert_element_type])" + "1162([Symbol name=neg])" -> "1163([Symbol name=convert_element_type])" + "1190([Symbol name=add])" + "1189([Symbol name=true_divide])" -> "1190([Symbol name=add])" + "1226([Symbol name=add])" + "1225([Symbol name=true_divide])" -> "1226([Symbol name=add])" + "1248([Symbol name=convert_element_type])" + "1247([Symbol name=neg])" -> "1248([Symbol name=convert_element_type])" + "1263([Symbol name=convert_element_type])" + "1262([Symbol name=neg])" -> "1263([Symbol name=convert_element_type])" + "1290([Symbol name=add])" + "1289([Symbol name=true_divide])" -> "1290([Symbol name=add])" + "1326([Symbol name=add])" + "1325([Symbol name=true_divide])" -> "1326([Symbol name=add])" + "1348([Symbol name=convert_element_type])" + "1347([Symbol name=neg])" -> "1348([Symbol name=convert_element_type])" + "1363([Symbol name=convert_element_type])" + "1362([Symbol name=neg])" -> "1363([Symbol name=convert_element_type])" + "1390([Symbol name=add])" + "1389([Symbol name=true_divide])" -> "1390([Symbol name=add])" + "1426([Symbol name=add])" + "1425([Symbol name=true_divide])" -> "1426([Symbol name=add])" + "1448([Symbol name=convert_element_type])" + "1447([Symbol name=neg])" -> "1448([Symbol name=convert_element_type])" + "1463([Symbol name=convert_element_type])" + "1462([Symbol name=neg])" -> "1463([Symbol name=convert_element_type])" + "1490([Symbol name=add])" + "1489([Symbol name=true_divide])" -> "1490([Symbol name=add])" + "1526([Symbol name=add])" + "1525([Symbol name=true_divide])" -> "1526([Symbol name=add])" + "1548([Symbol name=convert_element_type])" + "1547([Symbol name=neg])" -> "1548([Symbol name=convert_element_type])" + "1563([Symbol name=convert_element_type])" + "1562([Symbol name=neg])" -> "1563([Symbol name=convert_element_type])" + "1590([Symbol name=add])" + "1589([Symbol name=true_divide])" -> "1590([Symbol name=add])" + "1626([Symbol name=add])" + "1625([Symbol name=true_divide])" -> "1626([Symbol name=add])" + "1648([Symbol name=convert_element_type])" + "1647([Symbol name=neg])" -> "1648([Symbol name=convert_element_type])" + "1663([Symbol name=convert_element_type])" + "1662([Symbol name=neg])" -> "1663([Symbol name=convert_element_type])" + "1690([Symbol name=add])" + "1689([Symbol name=true_divide])" -> "1690([Symbol name=add])" + "1726([Symbol name=add])" + "1725([Symbol name=true_divide])" -> "1726([Symbol name=add])" + "191([Symbol name=rsqrt])" + "190([Symbol name=add])" -> "191([Symbol name=rsqrt])" + "227([Symbol name=rsqrt])" + "226([Symbol name=add])" -> "227([Symbol name=rsqrt])" + "291([Symbol name=rsqrt])" + "290([Symbol name=add])" -> "291([Symbol name=rsqrt])" + "327([Symbol name=rsqrt])" + "326([Symbol name=add])" -> "327([Symbol name=rsqrt])" + "391([Symbol name=rsqrt])" + "390([Symbol name=add])" -> "391([Symbol name=rsqrt])" + "427([Symbol name=rsqrt])" + "426([Symbol name=add])" -> "427([Symbol name=rsqrt])" + "491([Symbol name=rsqrt])" + "490([Symbol name=add])" -> "491([Symbol name=rsqrt])" + "527([Symbol name=rsqrt])" + "526([Symbol name=add])" -> "527([Symbol name=rsqrt])" + "591([Symbol name=rsqrt])" + "590([Symbol name=add])" -> "591([Symbol name=rsqrt])" + "627([Symbol name=rsqrt])" + "626([Symbol name=add])" -> "627([Symbol name=rsqrt])" + "691([Symbol name=rsqrt])" + "690([Symbol name=add])" -> "691([Symbol name=rsqrt])" + "727([Symbol name=rsqrt])" + "726([Symbol name=add])" -> "727([Symbol name=rsqrt])" + "791([Symbol name=rsqrt])" + "790([Symbol name=add])" -> "791([Symbol name=rsqrt])" + "827([Symbol name=rsqrt])" + "826([Symbol name=add])" -> "827([Symbol name=rsqrt])" + "891([Symbol name=rsqrt])" + "890([Symbol name=add])" -> "891([Symbol name=rsqrt])" + "927([Symbol name=rsqrt])" + "926([Symbol name=add])" -> "927([Symbol name=rsqrt])" + "991([Symbol name=rsqrt])" + "990([Symbol name=add])" -> "991([Symbol name=rsqrt])" + "1027([Symbol name=rsqrt])" + "1026([Symbol name=add])" -> "1027([Symbol name=rsqrt])" + "1091([Symbol name=rsqrt])" + "1090([Symbol name=add])" -> "1091([Symbol name=rsqrt])" + "1127([Symbol name=rsqrt])" + "1126([Symbol name=add])" -> "1127([Symbol name=rsqrt])" + "1191([Symbol name=rsqrt])" + "1190([Symbol name=add])" -> "1191([Symbol name=rsqrt])" + "1227([Symbol name=rsqrt])" + "1226([Symbol name=add])" -> "1227([Symbol name=rsqrt])" + "1291([Symbol name=rsqrt])" + "1290([Symbol name=add])" -> "1291([Symbol name=rsqrt])" + "1327([Symbol name=rsqrt])" + "1326([Symbol name=add])" -> "1327([Symbol name=rsqrt])" + "1391([Symbol name=rsqrt])" + "1390([Symbol name=add])" -> "1391([Symbol name=rsqrt])" + "1427([Symbol name=rsqrt])" + "1426([Symbol name=add])" -> "1427([Symbol name=rsqrt])" + "1491([Symbol name=rsqrt])" + "1490([Symbol name=add])" -> "1491([Symbol name=rsqrt])" + "1527([Symbol name=rsqrt])" + "1526([Symbol name=add])" -> "1527([Symbol name=rsqrt])" + "1591([Symbol name=rsqrt])" + "1590([Symbol name=add])" -> "1591([Symbol name=rsqrt])" + "1627([Symbol name=rsqrt])" + "1626([Symbol name=add])" -> "1627([Symbol name=rsqrt])" + "1691([Symbol name=rsqrt])" + "1690([Symbol name=add])" -> "1691([Symbol name=rsqrt])" + "1727([Symbol name=rsqrt])" + "1726([Symbol name=add])" -> "1727([Symbol name=rsqrt])" + "192([Symbol name=broadcast_in_dim])" + "191([Symbol name=rsqrt])" -> "192([Symbol name=broadcast_in_dim])" + "228([Symbol name=broadcast_in_dim])" + "227([Symbol name=rsqrt])" -> "228([Symbol name=broadcast_in_dim])" + "292([Symbol name=broadcast_in_dim])" + "291([Symbol name=rsqrt])" -> "292([Symbol name=broadcast_in_dim])" + "328([Symbol name=broadcast_in_dim])" + "327([Symbol name=rsqrt])" -> "328([Symbol name=broadcast_in_dim])" + "392([Symbol name=broadcast_in_dim])" + "391([Symbol name=rsqrt])" -> "392([Symbol name=broadcast_in_dim])" + "428([Symbol name=broadcast_in_dim])" + "427([Symbol name=rsqrt])" -> "428([Symbol name=broadcast_in_dim])" + "492([Symbol name=broadcast_in_dim])" + "491([Symbol name=rsqrt])" -> "492([Symbol name=broadcast_in_dim])" + "528([Symbol name=broadcast_in_dim])" + "527([Symbol name=rsqrt])" -> "528([Symbol name=broadcast_in_dim])" + "592([Symbol name=broadcast_in_dim])" + "591([Symbol name=rsqrt])" -> "592([Symbol name=broadcast_in_dim])" + "628([Symbol name=broadcast_in_dim])" + "627([Symbol name=rsqrt])" -> "628([Symbol name=broadcast_in_dim])" + "692([Symbol name=broadcast_in_dim])" + "691([Symbol name=rsqrt])" -> "692([Symbol name=broadcast_in_dim])" + "728([Symbol name=broadcast_in_dim])" + "727([Symbol name=rsqrt])" -> "728([Symbol name=broadcast_in_dim])" + "792([Symbol name=broadcast_in_dim])" + "791([Symbol name=rsqrt])" -> "792([Symbol name=broadcast_in_dim])" + "828([Symbol name=broadcast_in_dim])" + "827([Symbol name=rsqrt])" -> "828([Symbol name=broadcast_in_dim])" + "892([Symbol name=broadcast_in_dim])" + "891([Symbol name=rsqrt])" -> "892([Symbol name=broadcast_in_dim])" + "928([Symbol name=broadcast_in_dim])" + "927([Symbol name=rsqrt])" -> "928([Symbol name=broadcast_in_dim])" + "992([Symbol name=broadcast_in_dim])" + "991([Symbol name=rsqrt])" -> "992([Symbol name=broadcast_in_dim])" + "1028([Symbol name=broadcast_in_dim])" + "1027([Symbol name=rsqrt])" -> "1028([Symbol name=broadcast_in_dim])" + "1092([Symbol name=broadcast_in_dim])" + "1091([Symbol name=rsqrt])" -> "1092([Symbol name=broadcast_in_dim])" + "1128([Symbol name=broadcast_in_dim])" + "1127([Symbol name=rsqrt])" -> "1128([Symbol name=broadcast_in_dim])" + "1192([Symbol name=broadcast_in_dim])" + "1191([Symbol name=rsqrt])" -> "1192([Symbol name=broadcast_in_dim])" + "1228([Symbol name=broadcast_in_dim])" + "1227([Symbol name=rsqrt])" -> "1228([Symbol name=broadcast_in_dim])" + "1292([Symbol name=broadcast_in_dim])" + "1291([Symbol name=rsqrt])" -> "1292([Symbol name=broadcast_in_dim])" + "1328([Symbol name=broadcast_in_dim])" + "1327([Symbol name=rsqrt])" -> "1328([Symbol name=broadcast_in_dim])" + "1392([Symbol name=broadcast_in_dim])" + "1391([Symbol name=rsqrt])" -> "1392([Symbol name=broadcast_in_dim])" + "1428([Symbol name=broadcast_in_dim])" + "1427([Symbol name=rsqrt])" -> "1428([Symbol name=broadcast_in_dim])" + "1492([Symbol name=broadcast_in_dim])" + "1491([Symbol name=rsqrt])" -> "1492([Symbol name=broadcast_in_dim])" + "1528([Symbol name=broadcast_in_dim])" + "1527([Symbol name=rsqrt])" -> "1528([Symbol name=broadcast_in_dim])" + "1592([Symbol name=broadcast_in_dim])" + "1591([Symbol name=rsqrt])" -> "1592([Symbol name=broadcast_in_dim])" + "1628([Symbol name=broadcast_in_dim])" + "1627([Symbol name=rsqrt])" -> "1628([Symbol name=broadcast_in_dim])" + "1692([Symbol name=broadcast_in_dim])" + "1691([Symbol name=rsqrt])" -> "1692([Symbol name=broadcast_in_dim])" + "1728([Symbol name=broadcast_in_dim])" + "1727([Symbol name=rsqrt])" -> "1728([Symbol name=broadcast_in_dim])" +} diff --git a/examples/dev/forward_trc.pdf b/examples/dev/forward_trc.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a87d5b59743c5075b46d97d46f8b90a14e3a3ad3 GIT binary patch literal 13368 zcma*O1ymf%7VnJ)cY@2{1f9VN7TjHeyF0_+5FmJPcLD@=cMEO_?gV!Y5=hVo$vNkp zd++zH^`>dr|E@0CRo%Td^=ld>F$rc63nwy7{c*)DG6w(%a4@z;=H~~n%9+_)Kr8{A zPm&5U003Z>u(E}ifuHuaMi4VGGZP0>Gh{(QWM>H2%*YPeJ>!eEd^AZTR{M43-sYvz z64N`}AyF0-IB3+Z7)@>+#Dv%~`c!q?zTtt+F)#5Ni?-T!LRb`0;FkSWj4jvUOzGRc z#D(=vzR!;}{l9(!?@hp0N70?>ogSMz*N+*d58PQk;l93^9lipsGhZC&oGwl+oUk=M z?yBvCl;?PqEAJ2uYkPmLa}G;Q-po+drt#vcYvK%9@gS0ygOFTTm$A&)$w9aUPWQpm1HW1ym_M$LU02eEl{$YNnD6wkF|W|pD4&tdG!T$my>s6^0bbu74D8z$2sO(q6!`OlEC56+HXD+U&Bn@SuM#e34cQyM7l{pH=4wY6vD zfWg+UxF{+IjyWymmx%gfMqaMiBAhn8Z?Qqlq38w|nnyw!f>Ssrt6ngzXI1rLqjN}P zE-G~14)uHglKbTCEpNhKJp=EB#YoZW*|lZFLQ@VO4nICyWK}UwFO9yWdA!0$A#(ez zNuf=e=tlA2UiG>fvSK(b!ssOd@Hb0bJAo;A>e4C~@8!{+3bh=iR(KKYKg$ucAR!0*YnlP!R ztNu>v3jan$uf3$cYNOs$$WHVIHvPN1y@$p3h2RekILJa^uSmmOJ2aocZDa=t-7ahF z;6T06K77ZAsSq>D{uSnlw$-)kv|0h^CRAE%gf3_|5&fA14Ff;4CSu#9@dRyuq3Cy( zvRa8O5tZ1ie40`!Z%1k4<`Q^q!0=u}bXH%)4c!$2uRk=Y3y+I+&lKGjLjjj}G?4FO-L$3>Rca$MzSZm~SlqmC=% zh4xkwC?Yo9shVl_`+?J-KlH{Zs|HV0xA$dwT|ix#V=DlnQWcRF8AB zmxhd&6Db>3|LoV^xZt+EYdt-RO^F=;I^1YZv}m{_z;h!!_Qv?HKnurhk{`ceoY zANp53bVeY2C2P84NGQJ!0)qerCiTuy=J=tnV??@K0ej>0xVivYS$wcQEM+!NePwT13X56pr{gUoA4>i}C%f{+_uqd29V=yDNTHy7S|=W47v*z8 zF_=|qRV+*)?ha5cGFo*T5?B=JBtV<%>WKWbxM>^JQQZz(e&#vna^1IR?}`zK&t$PG zboX|b92wsDh|o}kYA3}Fi|=*TUJ-pLSum07Ej?h zsbUsC*bErmc=?21P%@}c5RR+!=3>F10D22gBKIuXeCn|oxIWmdF49QSYFvu1{6Ayz zkV~Q&UbGz=UE_?~p=?;Z<$)4I!ox6#j0M5S`ZHoSFTXI+aLl~opX0}9|8z5B zN}JE@K`C;`CA8L!%}-^`mAnp98LXV?xoQ5xJa%@C`l$c-2QC?UR-vmnmzYQ`^9wuS zZQ!*nY(aVXR?$RPUJ`Po6K-_I7u1IKfb5e>2iSsN#&(T#BDk_?uB24nE>T#1m>a03S4&l=_uzCAvb2i^3jljR3%r-Uo!ejBb9!EtKY-Y^iA-LSm~d*Nk5=OG<>zOqs|f%A3Dh%APH^9V7x><4=FQAYCU0@l zV(l(`5lWqKgZ)&k=-`T#JDVg_CV0Ew}b zw4V`+GFU76gCV=ltQRNh4G#ZKEIHeWMUpnIQaMf9o*|4L6tMwO^@X6w5=mpLQ>Ob= zobSpjkry+SAJ{WsbNrp4u5u3Pw}V{X-5l23r;&=s`;fxC-bS#sNq4CcG&KUyB%HZK zhkR4aiAx15ZhtM$%0@7QRPO}=Cq5Cdjd2>1HPC-#uNDlxL=d=G)D~bR7ESGT>P{Mt zoYQ1zEYHDc)?Gi&O*cMWQ2RQ5q7QX^{5jd;0HT4i&zn(?xtQ8>MwVSo^7VurP4I-J za&W!i0Vg7CFgeX9QF|%Z4tF2Ssyex0h8MYwAL>pI3@rx0AV<7+Y#xz1m4L2pOw@Why z^=J_2=ULz)J;RVAI$aOORL_-|bfK6bI#cd8e;smdZh9a)BhWsoPg}|u^dV6;Hab#_ zjd+w<$B2gH743G1^5G1RskKrhbwzEI{Bl2J zuIzF|=F&A$$r{8Mk@)X?+l-;ul<7nf>Y)hXrFAwU{GV=kIAu{9h#4sJXlW`3{2yuv zoNji+KdHB{1;UeryJ)5>PE@RM3~%F>(Fmufko@&#u~H3QQsT*zx8w{v=0qA?ogEK( z>Im1Hk+Q^_HdeX|6wEcg0=mcQ>CGt{SML&Cc($BPue?CH9P%nI2k>f-n1oMtwUpm% zE+QTx=WMNeFydJ%=CGmc5o!-rw5B@=q8;Ued!PN+Nlhf7|vo@^FLiBsIz8D zI7D#jNeNQLNlXdKB7i2`;m{@gMn(*W)(if*bjK|km?IL@WD-!hB?H(Ohuhx+6leWy@qX1Iby~WeBg{stVvZ~&E zR|=+Y4^FrCpvmXI!}3^o#={_Gq!#E8Z*AXAYU7t$!gwumrdh*o)0lE&c097GRHVW#ZEy987VH| zJuq%5UYV(r$i-o=mYRAm_|7^fT3lFE2qVhUIq7j9nrS{(PCe@a#?C_j=%vm=O(822 z3c7)1h!zv{PFANubPo}9s_{U~;u!m2GLxz>Qf)57JVTZT&01#Pt0B>M$Y7eBF1a#q z1

K@EYSTu{42J<&;O!YX+Nyq#%y_ks9orui`w6H7a?>$u<;JS|;kfo~!eMU5g}| zncGKnwJ)joH(vym`&`fPd>8QRSJ46=b0mtx%!$c821IX_W*#ba<0A9X+WowAov%`@ zG&2(`qVte5Qa7~f{g|@(B3?pH8il)h2D;qPi}U223Z4?1qOaa#E3&=8+B*Cr7X`(; zAY)fkt(ebZgj2RXs_baj_$U46u#AzxZ_SGdB+=wz3tH_lk57s37DFc+!8nQfcJR0# zV5aVmyW-#z!l5;NZGXucn?}#bh~{TvXxz$1 zC=4U^4My^}4JcI7Jk;?~t8p~w4BkKL2*-|G*g*+@KLaGUJxb8ss z(jRv=1XP%yR`<3~LF3ly+~iZWaM(w({Fz*nLD;W-`N<@r*f}AD z3&pAuriE$5v6;x1kK(RZypmVizGMrsJ|qi`57$%V1&`3)4F}A~X7;9kMPE-x&vDyx zu=c!v3fI^+;P<#PFLzn zOG-<~)n~rVWS%HYe)HgzCc@?m#LR*h%lbefOl?gqf-OSVwv~uLFuk6MNkssIBQ5qK zde6`?0S7)c^o(4iPDPH=m`d0ul)Y9e4q~onZNMGC(r`L2vy6(r@w~b0V7dGu)xJ=S z)ykwiB%Q#t@J(7ONl8KMQPysCYO+SMQpoZ#YMcC~Do-yh!eV{F9X%`G#d6H1bNV&j z23{L3YY($_^K70#MPh&{!SS4aA(19sO|JWlbq%F!P^9>KIr{YJxm8mC>f6<^v9W%u zn`{qcn+SQ%JI7z=7+&0Vi>2)ya3A|noR9=tO@-A;r?O_zEnSpNFTds72YC|QaOL36 zC{&sjCtAV}u;J47M$w}O#<#4E4*+Kwin?T8eMY~_nGHYI#ygC7_|fOzkrh&ZLo~}k z8ZsSL1@}sL$QDF*Qyl1LTAkv+tpg*x7<8zFRLikdqAJrSdmJv>M+jX@lhV~J9)yjR zT#=%JJsB%yCgt-@z=EYKSEjv2fIMoR^16q#{H^P*#aN2K81+I&$cVqH(PX#^ZNGR4 zZ%k??FQc=^POw}7;@Do*Sn9-jit}j%`fi5TZhJ6W4D!^MH-vL762l%;C^PD!fe_~w znr~?*iWqqaQ}Gh#zq;u!=tt;nI+%r=5`5JR)qp})rso_+X&xK}WN*;;Esouy7khY@JU>kUMzft{C za`o;K-u95OQ@2w@@cg{Ppv>z!@%mlyz1D5V)uP&)$(%xBuD2D#!~*MEwRD0+X~mh# zbKf29#J{ApRae$$3gCk-A;%w*6z3)=aHGStLzn287QN3B@aw~Jzc49zr_0wl+0mk> zE%Rjc;F@a=b42BSKh}kz@x!22{Rh5F=HGqT>8XuAW!F3q-u>(og$$Q@MC zT_>%U2cgaUg2^^Yp(7<1zq7tDtvAZF!{Um^x5V)_#fD{oMh&Donv^n|UElPZ969?!7c}TL7GvX;6fi&v?zg%G%`X{6-{rHTl9~> zgesG}7)J>}0=B|PhxkG3I8O7pTMqemU*Y#KDDEi3MC=9JlZR23h7?$u>_DU$enB>& z8m4p2JgRO((wkuts5|Dl7DI)`H_9e{Zl*|DJ4|jG%U0`kzv^xwx^DwnyEosMCR6Ux ze_kf>a8HZ+{8*e*8Gpu*nRXkVKSsh+qa7Bi~Tesvl0 zcJ+fq@pNHh`%UgnP^sh6&iu>N+Mo4%=h|xv6J4a4U7`xOVye^A04hc}nv;zU*9Av3 zEq`mfGLJ!|9rGALtFurhsWJy{5P0sFl%AMcIz*r4bXKNgpbO;_-| z(klh}96|a}74VC85%u{uQWsA~l3Q&)66PHDhr7*kb8vD_pGL4*UmnE8{aSRW(4APx zGynrXymf28Ep+aDJmBt>XmRlWB_Z6|VlW8TIc{l0eEg%u;5!3++uPBmb^J>YvrgRa zOczQ`^gnLq74UzCEU<DrJN+tk!#7SLn6sxV|g&2@t?%HIBYI-V!osay;7zLh;oP zb>={mtXs8Ca~)n=oB)RL8MV9TmAEzp>j-}$LJm9@&@kxnF#DTm5a=ksdT6 zVP6<`&5Q0G-ej1(IocnN3P9j!DR1)ZDY&9%1{yW{`9S3DQ?|Z z^oA5h7$Z+s%XXo6R6ic(8FyiGpg$JBQt7XWD520JgnWr%Sr&JRwQot^H*Za}f7{r~ zSTQoSU61GA)(rXqi%J@-_F{7jojevNLbxX(!)8#fren3mYNR9)Efb**)?|q>%6r@k zPnD0G#YZg7;O=~R)i>PY!Abs0jil?^=3Ai;6&ZrR+;3T<6iQb#Yu0KnFInve1@O0A zDaq4{X;m^*0-6NmKG9TQ-RmOQ@((cTZ}mVexPQwAPoSXkq_}Dc~*kC35 zZ4*1FX{@72=TQV9yM-KMPLS{Xz=lpg?8`M1i%AFIHX?;txLz59@=7?(#E~R&88b3s zKTH>634l5=q7~edyPtSu)45z>!i8b_07K9npiwGqpIWb(IV~ahRhueTqa#Jah#QB} zh<_L_NhFDhRjOR|kbv})<3ZS-ZOH)ff}dP%lw2hit9z^)K;OOi>)iUM3fMUtsJYTk zbN>F%>C@JZN)gUmrUy9ur&SHy7QTJ47U&zlE^MIcAw|9_m+p{-){n}+UcUXv^K1Uh zVEv+BJ7RWcGBcYGMMlD{*eGHP&N$!ALpUa{H_fkVMC+6pou9Iy;AhX^mdox7yLC9E zpt6NmbH&Py{KbU>fSf^>cizP;^;9qC5gECwjV#CKUo5bhgMW4b=v*HHnO#b<3s}ei zeaVdID*)LqZq+4oIuJEg6zrd-&b5_P9o>8Pta+*a#wy^@7u&^N30T= z`_8Pju$X^?w)*bki`(HW2AKAq*|7j&uJ~fvj(^u$RhOTxL7H;Fxq&NdmqcN>A|jqz znZyD|k7e66c@S%OSTseory|r@#j)GRM23r1g0>H9i z7xoL<&wa8`A5=753>}m8D6AH?>OF*q3C~fgo6D44H`*wC5ByzMC;f>-na#;Q$ea{t zHSsiy;p_GZ9==m}uXqLcTH8LnREvMOZ=y{am`aFZVrDnk{~@F2h@<(aP1G@k6fb1P zKHm1jo^$^z`#H>ap}^h@b6NZl8A5Qdz!agzq@Gp>Rua}MR)hDex!B5xYxx}%Ciq#H zrSd3|cUizO??#@tSvDUJ>32t%%x1+KRGeZj{qFpH{69@=?J!)apQ(!^HcOPml^)rU za1Of;xbAA0m6lG-F>%ROey>EXOykpY2JFZTfb5eIL@Nm<<-q8i_tuZXU4$o#0gvu# zhyATry2myX3S6Yq0xF$^LR!i>nXoP48_C=@u*lYLElIid{s3ZO8DcRvPmer5&?E$-#Dp!cDfVW3sjbEs5#pk=1Yt{6M!fn2~#odOig^Cb^)KYu~^1c zS{%u_u*MO^{u#(2TD=LBXiSd@GPzU=t)onTISEMW55@nIK%ePBE=OW+V=8nmBRLx5%9Z zyd3ZnktbShQuzidRb78TEWB^p+!~|y4@I_A>o-epP9_bqH*G&&;ajIxhWF+Kg{_@SP|^mC{$4)ctTb&i zji8y{qvO>Yb())!)fou-!SIBSVejP%k2!Ebna!?pu1xfqgH>H@`iD|x1^W}1fYdOZ zUYiqHjGVCNjqg`W13b*LmuBPybEZ!<%~)NMZ~89Pi#W;?nDD~cVZ*#dz%71pTC*eC z889`fsnzs2sEuzQhFk}9+C76281cdbGs~gI}CX@IbCR!HGzmpUNuWs_iyxuD&T3pFmsRrzc zQDxWS(@_G=JJ2`J*~$=A3qM{2e??+Pger?OsH;!Vf`(}GQnu9=5Y0<8qS6OR4p*zA z{>7ynRO)gk^P@hhGY+}hgm5Cth-?XD)uQ6}NL^88WQRLwNcs`h&PCT8A4u9?pH9xG zPe(O6zIT1V?ht9aU$wNwEo7naxfcn#3aY5=$*#2C!P%1RR8vJtSK4Ma%U3rpu1%<4 ztGD#e=ZOKRico@Uy>ppD)u2R@PS^#$wdB5VeLy9(!aoZ3THD>j*2AYZM1)x*+8Q|| zR=E*{4T8SUPf4(%#IlOT(wd+K=Sp6F5CrMVVd5}TVpo&@a>5(#E@P?19+0raC%;@q9Xr;U6!*`h;ZJwOZJm=f_d#sS?sqf3O2)u6p11P~7^;oc#Qp>gMgl_q&Xp zy<)BIXC&0Bhb11<$x@guwe$lUzy2d=GJW0fqs%2~2HH0If8?RYe(b6v=Si?W}3cYf_OE zf}?lxfUM3RPKnbUTiJ+-u6RPKx?IqkeY z>@3Ll(bB(j_!dcX;k1&-=-zu57w}U1zS8$%=5*|S|69epAe-T-a+M^yKoKowZU}|g zbuB8bjXVl42PjaQ@<9zAGK385iA%$Y);jeySayxP@T02y>*ZyG?gXrd z7sjmHG-R~{8}$si1Dj_I9TfDWA;myx&O+^AP0WWS!y2gioVygJ#!VreD`Y zl{Zw>mt((`F5Oc9%3v00Q*Tl4R_}nbXkEfRr_-R_xX2Zro|&E?pC!K{-{Ivl3|Y=_ zOff0+M+{55Fr!+T%FY~Uw3|!#vih!3zSOp>c+n<$gY}@OoCkp|x?Y>ilx(vu{<|B# zAhZ>WDe3EL4%9n|grNRn=~qs9YU{E%eUlM=VSNj)obdQ?YJj3xhpL)fxEvOTUB|?x zqo2v61&?QH!_&(?EJH{yeho>#YVMKT+826>D%P+y(=b40Gi7tr;8MTTI9IWx;jtHZ zwYmK``Eb$e-1@P_`|4=B&{6g_mg-oMlviZd*2-$3rapgL@4C1l`s>HXqT__6>+Ymv zg^%y7MoY3Ji2D$Eoz1+@83SHD&LXej-S(_xQU-rc=6to+brflYd z)z8;c={4QMmKW>g8mzrERP;-+q3pzYquavGhs*583i<4QVHW}$qxVG|@L*Iu1voDx~NP?C_lBORz_sd9@3xtc7ozQjW;N5y`68uLWx%7bZWv0f+zF=eNA!PbJ|aX zXicgMrWZ+>9wkLF0*v&{BBE?Hob4l2eElc-K@Sp7AZ|1vZ@W7p4SnbzRx_}DOcSMH+jJM#S6>5i_E@* zk@a}oQU(7O5>JGc zoVm|(04_5MWYaPKK=|rm|16Ibm`CjCkvCC#h^gFA;?O8sov_2sB0(KuKQ)M8ZQ{`C zkq;2wXgRtTWnWXg(flsA56$33PT@wU^o>8TZ5#1du;}r`<5)!X2P@oYHs$c@MeIdb z{!mNy2k1ozUA^%Afn8+VHo(EP5a^H;!>!)J7u04%LuY^iqLUpwq?l9}d)J;t!X0m7 zety8Hurq_`(I@?yYHO7PtVYFZv-Ot9LzznQbpcSaA|~q97JMX$zT&&$*j>A=W>LHy zUOHY=iZO}7&ZHO%>>+iR*PkBRjm`j{oZb1|sONx=6O3Jp7Gl}Ln z1N<=UAk;r?V|v>_X?_c;@qul(vtg($54Ba@F9%kZ7n_clr!K@vH9+(aRVpXa&1$fM zS2RW5CCQVy#7d?O^_bV>=ZJa5?cBBn$wH8H$mh!(@Epk$fK zm^kbjB?=Udb3cg>yBH=PIN5v&nhmJLiJBc=G`k5jOHu{Fn`TLz%nmM_Y_UX+vQW6a z9_kbuEaMY1pY;TXNSP}r(Zwx|ntCFoCiQ}*(uSrHQ$?uFKLp4%#Ou8BZ=MlFrsYj} zkIZJ3+bai-3Bws36$_J7@-Tm&TEwsX!zDvL+JSa+FK4uAX=om#V19(kW?t77;S>jA z5YPsEIzJvQqKp$#F+V>x&67JXGYx#Po*h3lGx(hZfqhvTObHXe35eBkr=^UMk`rqD zPR+*{HN*nY>cORDj52kZ=)s0JPZ0^W%MLlG867vr{t97RC#1#+9E8NlC4TiQwn&J% zq*#!6EiDH04 zNPyGpFx7lgbmNfgqk&M=5k?!a_wgLDp+Mj2)cLhkc`A*)E+WY?2D^z+LLze)Q+TX3 zQX_$NSwxFC!mrYDY(}YwJr#7Q-UDKYXqgkDNDvt$xG`EOswSLdTidSG$1$?K$1Avp zedv{=PG}rPGo$~VDu151{cEb6o0I*YsdC=`o#F*OPfP#5A<%kpWA*_+)c#YC4Aks-9rR73-3ifJ&q}Ci`zGP>TXDKtY}$3Q$GR%i{RvXTW@%W*ZHRZY2s{) zdwA9+X4K7zO{1i>Z&m@hppzz~xIyQ*SrN?o?Q$_w6TjJG<_03jl&o%jTl%w2uB`N~ zEtrPeGpq2c(@}S^@qwrPX_>7jACK+vMxufZ$%&ol!34at@rF5Wi$T6u!ZzDW2 zUt2I8`>qs@&? zZ;0u=U|2_WiVY10tx@=?!}cvsN+8;${A|ca+%0frBmErHt@{$WTNJ2VL4UE0y@pwI zZAo8olUncBH41J@-`xN9_U}IYn~dQ7r?)}R7{cE(`hPfrXVO8;%-IBNxEG(bEil=OTOR5kvI}HHOpX;9uslScqCO**$W^90GpyA(u z!!yzGzwY>xor^K#H_f5;git)Y>bELl(iB{{P%<;`UG97+Bd`09fBz*$dk{Tm3EkF8ZyNDa6wGi5g)CJvZau3p+0m@Vx$e z;bP|k{J*S!uHj+-?>zsV*MA>#@jQb!|6cx-{%6bi^e+)0fRp#Be2%9**AqDNyytyx z==0H2T|v)}9j>QA@O;F}!SkeZJlSl2o#)_wep+z=LBDm*CphLgANHs9FHGhMcX3d+ zw|f301Mr;Hzx(umhxjv(^OT%Pfqe^D=>e-k}QU z=uM2QzzzTq3y6ya2%xuwKpdT4vp%6iPpSpj$kEcu#F@nbZ1L*PO(6R8kqE@f!Cvec zJE9kR%?4!S0CEFCpeIfV#HIb}-{t*ZT4yqH}{h>o__%Qzce5_7uQoy|Ch$e`Q-2a)PO+HQy>4Q z#{Qi9e`+8$cHmR>|0^92&(n?d9~ub633}?q|JI&I-G6FqK;VDZ1;hqq|L=By*m$1C z_kUmeJd*y=AYh}XdmjAz+a6UbPqSxVJwM%492}nd`w!~$Y)jjlKXvQxA@PS+G6F+> Sdl$t1) diff --git a/examples/dev/my_graph.png b/examples/dev/my_graph.png new file mode 100644 index 0000000000000000000000000000000000000000..08e87a916a60d35ebd037f88de169c4df81e2578 GIT binary patch literal 144 zcmeAS@N?(olHy`uVBq!ia0vp^+#t-s1|(OmDOUqhY)RhkE)4%caKYZ?lYt_xo-U3d z5>wA!*vNapfQR{DdO}9?o5iW2l@lri-Y@JCFk?IPOKRV)-5<^Lg_#QY`d29_*)3cW o-B4;^{QX1nhd)c%419OTIjHb$ogtHI2Q-ku)78&qol`;+0Q9ObbpQYW literal 0 HcmV?d00001 diff --git a/examples/dev/simple.py b/examples/dev/simple.py new file mode 100644 index 0000000000..a2ceffcf7e --- /dev/null +++ b/examples/dev/simple.py @@ -0,0 +1,26 @@ +import torch +import thunder + +class Module(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + + def forward(self, x: torch.Tensor): + a = x + x + return a + +with torch.device('cuda'): + model = Module() + x = torch.randn(2, 2) + + jmodel = thunder.jit(model) + + ans = jmodel(x) + print('---------------------------------------------- all traces') + for t in thunder.last_traces(jmodel): + print(t) + print('##############################################') + print('---------------------------------------------- ans') + print(ans) + diff --git a/examples/dev/simple_log.out b/examples/dev/simple_log.out new file mode 100644 index 0000000000..aa715847ae --- /dev/null +++ b/examples/dev/simple_log.out @@ -0,0 +1,132 @@ +Interpretation used: INTERPRETATION_OPTIONS.TRANSLATE_PYTHON +comp trce before +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +# No signature available +comp trace after +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + + # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x + result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + return result +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + + # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x + result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + return result +============================================ START: LABEL default +============================================ START: LABEL computation_trc -> backward_trc = None +============================================ START: post_optimization_transforms +[] +============================================ END: post_optimization_transforms +============================================ START: before computation_trc python Callable +# Constructed by Delete Last Used (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + [result] = nvFusion0(x) + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + del x + return result +============================================ END: before computation_trc python Callable +---------------------------------------------- all traces +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + + # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x + result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + return result +############################################## +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + + # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x + result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + return result +############################################## +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + + # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x + result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + return result +############################################## +# Constructed by Transform for execution (took 1 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + [result] = nvFusion0(x) + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + return result +############################################## +# Constructed by Delete Last Used (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def computation(x): + # x: "cuda:0 f32[2, 2]" + [result] = nvFusion0(x) + # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" + del x + return result +############################################## +---------------------------------------------- ans +tensor([[-1.9710, -4.7323], + [ 0.1026, 0.5416]], device='cuda:0') diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py new file mode 100644 index 0000000000..8f8c3f87a7 --- /dev/null +++ b/thunder/backend_optimizer/optimizer.py @@ -0,0 +1,80 @@ +from typing import Hashable +from thunder.executors.data_dependent_partition import Graph, Node +from thunder.core.trace import TraceCtx +from thunder.extend import Executor, FusionExecutor, OperatorExecutor + +class OptimizerNode(): + def __init__(self, node: Node): + self.node: Node = node + self.candidate_executors: dict[Hashable, float] = {} + + def add_candidate(self, ex: Executor, benchmarck: float): + self.candidate_executors[ex] = benchmarck + +class BackendOptimizer(): + def __init__(self, trace: TraceCtx, executors: list[Executor]) -> None: + self.trace = trace + self.computation_graph = Graph(trace) + self.executors = executors + self.memo = {} + self.default_cost = {} + self.hash_separator = '#' + self.dummy_cost = 1 + self.optimizer_nodes = [] + + def __repr__(self) -> str: + ret = self.computation_graph.__repr__() + ret += "\n" + n: OptimizerNode + for n in self.optimizer_nodes: + ret += str(n.node.ID) + ' ####################################' + ret += "\n" + ret += n.candidate_executors.__repr__() + ret += "\n" + return ret + + def write(self, file_name): + with open(file_name, 'w') as file: + s = self.__repr__() + file.write(s) + file.close() + + def subgraph_hash(self, nodes: list[Node]): + ids = [str(n.ID) for n in nodes] + return self.hash_separator.join(ids) + + # TODO: to implement + def compute_default_costs_subgraphs(self, nodes: list[Node]): + hash = self.subgraph_hash(nodes) + self.default_cost[hash] = self.dummy_cost + + def build_search_space(self): + visited = set() + def dfs(node: Node): + visited.add(node.ID) + childs = node.children + node_bsym = node.group_bsyms[0] + + optimizer_node: OptimizerNode = OptimizerNode(node) + + print(f'Node id: {node.ID}, symbol: {node_bsym.sym.name}') + + ex: Executor + for ex in self.executors: + print(f'analyzing executor {ex._name}') + if (isinstance(ex, OperatorExecutor) and ex.can_execute(node_bsym)): + print(f'{ex._name} can execute symbol {node_bsym.sym.name}') + optimizer_node.add_candidate(ex, 1.0) + if (isinstance(ex, FusionExecutor) and ex.can_fuse(node_bsym)): + print(f'{ex._name} can fuse symbol {node_bsym.sym.name}') + optimizer_node.add_candidate(ex, 1.0) + + self.optimizer_nodes.append(optimizer_node) + + if childs: + for c in childs: + if c.ID not in visited: + dfs(c) + + for root in self.computation_graph.roots: + dfs(root) diff --git a/thunder/common.py b/thunder/common.py index 23770a0975..94328dfa4e 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -613,6 +613,7 @@ def wait_for_future(f: FutureTensorProxy) -> TensorProxy: def transform_for_execution( trace: TraceCtx, executors_list: Sequence[Executor], + label = "default", *, only_execute_prims=False, use_rematerialization=True, @@ -629,7 +630,7 @@ def transform_for_execution( # cse_trace = cse(dce_trace) # traces.append(cse_trace) - extrace = executors.passes.transform_for_execution(dce_trace, executors_list) + extrace = executors.passes.transform_for_execution(dce_trace, executors_list, label) traces.append(extrace) diff --git a/thunder/core/jit_ext.py b/thunder/core/jit_ext.py index d69e8796a9..c5b4262b24 100644 --- a/thunder/core/jit_ext.py +++ b/thunder/core/jit_ext.py @@ -1705,6 +1705,7 @@ def thunder_general_jit( process_group_for_ddp=process_group_for_ddp, executor_lookasides=executor_lookasides, ) + jfn = interpret( fn, fn_lookaside=general_jit_lookaside, diff --git a/thunder/core/prims.py b/thunder/core/prims.py index d84536ff7f..b7e3e9dbdc 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -3771,8 +3771,8 @@ def check_sequence(seq, seq_str_name, rank, *, min_val): utils.check(len(seq) == 1 or len(seq) == rank, lambda: f"len({seq_str_name}) should be either 1 or {rank}") - # Check all elements are >= min_val for i, e in enumerate(seq): + # Check all elements are >= min_val utils.check( isinstance(e, (int, IntegerProxy)) and e >= min_val, lambda: f"all elements in {seq_str_name} should be integers at least {min_val}, " diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 84804080c1..5b3491a5af 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -3609,8 +3609,16 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa output_spec = None + import pprint + def augmented_forward_fn(*args, **kwargs): result, env = augmented_forward_pass(*args, trace=trace, **kwargs) + print('============================================ START: augmented_forward_pass') + print('result') + pprint.pprint(result) + print('env') + pprint.pprint(env) + print('============================================ END: augmented_forward_pass') saved_for_backward = deconstruct_forward_env_for_backward(trace, env) if torch_autograd: nonlocal output_spec @@ -3640,6 +3648,7 @@ def ones_like(x): else: return None + forward_trace = construct_trace()(augmented_forward_fn, *trace.args, **trace.kwargs) # We set forward trace to construct proxies because we need these proxies to # have different names than the ones in the forward trace. diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 47dde612d1..4ce1c5b31b 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -20,7 +20,9 @@ from thunder.core.trace import get_tracectx from thunder.executors.pythonex import clear_mutable_collection -from thunder.extend import Executor, get_always_executors, OperatorExecutor, FusionExecutor +from thunder.extend import Executor, get_all_executors, get_always_executors, OperatorExecutor, FusionExecutor +from thunder.backend_optimizer.optimizer import BackendOptimizer +from thunder.visualizer.graphviz import create_graphviz_pdf comment_symbols = {prims.PrimIDs.COMMENT, prims.PrimIDs.UNPACK_TRIVIAL} @@ -57,13 +59,30 @@ def preserve_bsym(bsym: BoundSymbol) -> Any: # If the executor has an execution transform, it's called and True is returned # If no executor can execute the BoundSymbol, False is returned def visit_helper_(bsym: BoundSymbol) -> None | bool: + import pprint if bsym.sym.python_impl is not None: return None + subsymbols = [str(s.sym.id) for s in bsym.subsymbols] + subsymbols = " ".join(subsymbols) + print(f'\n -> analyzing bsym: {bsym.sym.id} - subsymbols: {subsymbols}') + executors_names = [str(n._name) for n in executors_list] + executors_names = ' '.join(executors_names) + print(f'available executors: {executors_names}') + print('what?') + ex: Executor for ex in executors_list: + # TODO Consider allowing operator executors to claim portions of operations # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor? + print(f'testing executor: {ex._name}') + if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)): + print(ex.name, 'CAN execute', bsym.sym.id) + else: + can_fuse = isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) + print(ex.name, 'can NOT execute', bsym.sym.id, 'but can fuse? ', can_fuse) + if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or ( isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) ): @@ -86,11 +105,19 @@ def visit_helper_(bsym: BoundSymbol) -> None | bool: raise AssertionError("Unknown executor") safe_map_flat(update_swapmap, bsym.output, out) + + # Here is the point were we return the first available one + # print('swap_map -> sym swapped, symbol:', bsym.sym.id) + # pprint.pprint((swapmap)) return True if bsym.sym.executor is not None: + # print('swap_map -> sym.executor is None', bsym.sym.id) + # pprint.pprint((swapmap)) return None + # print('swap_map -> nothing happened', bsym.sym.id) + # pprint.pprint((swapmap)) return False def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: @@ -128,8 +155,14 @@ def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: return extrace -def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: +def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], label='default') -> TraceCtx: import torch + import pprint + from thunder.executors.data_dependent_partition import Graph + + print(f'============================================ START: LABEL {label}') + print(f'============================================ START: executor_list: {executors_list}') + print(f'============================================ START: always_executor_list: {get_all_executors()}') start_time_ns = time.time_ns() @@ -141,12 +174,44 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) trace = dce(trace) + if (label == 'forward_trc'): + print('============================================ START: before _transform_for_operator_executor_execution') + pprint.pprint(trace) + print('============================================ END: before _transform_for_operator_executor_execution') + # # Step 1 Performs execution transforms # + # print('#### A') + # custom_exs = [ex for ex in executors_list if ex._name == 'python' or ex._name == 'torch'] + # extrace = _transform_for_operator_executor_execution(trace, custom_exs) + # if (label == 'forward_trc'): + # print('============================================ START: after _transform_for_operator_executor_execution') + # pprint.pprint(extrace) + # print('============================================ GRAPH: _transform_for_operator_executor_execution') + # g = Graph(trace) + # create_graphviz_pdf(g, label) + # print(g) + # print('============================================ END: after _transform_for_operator_executor_execution') + # extrace = dce(extrace) + + print('#### B') extrace = _transform_for_operator_executor_execution(trace, executors_list) + if (label == 'forward_trc'): + print('============================================ START: after _transform_for_operator_executor_execution') + pprint.pprint(extrace) + print('============================================ GRAPH: _transform_for_operator_executor_execution') + g = Graph(trace) + create_graphviz_pdf(g, label) + print(g) + print('============================================ BACKEND_OPTIMIZER') + backend_optimizer = BackendOptimizer(extrace, executors_list) + backend_optimizer.build_search_space() + backend_optimizer.write("backend_optimizer.log") + print('============================================ END: after _transform_for_operator_executor_execution') extrace = dce(extrace) + # # Step 2 Fusion executors can transform the trace # @@ -154,6 +219,15 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) if isinstance(ex, FusionExecutor): extrace = ex.fusion_pass(extrace) + if (label == 'forward_trc'): + print('============================================ START: after fusion_pass') + pprint.pprint(extrace) + print('============================================ GRAPH: fusion_pass') + g = Graph(extrace) + create_graphviz_pdf(g, f'{label}_fusion') + print(g) + print('============================================ END: after fusion_pass') + # # Step 3 "Always" executors are given the opportunity to execute unclaimed symbols # @@ -162,6 +236,15 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) # NOTE This occurs if a fusion executor declines to execute a symbol after running its fusion pass extrace = _transform_for_operator_executor_execution(extrace, get_always_executors()) + if (label == 'forward_trc'): + print('============================================ START: after _transform_for_operator_executor_execution (always)') + pprint.pprint(extrace) + print('============================================ GRAPH: fusion_pass') + g = Graph(extrace) + create_graphviz_pdf(g, f'{label}_final') + print(g) + print('============================================ END: after _transform_for_operator_executor_execution (always)') + end_time_ns = time.time_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 820359f4d4..b543f06921 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -122,9 +122,18 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if not any(requires_grad_mask): raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") + import pprint + print('============================================ START: computation_trc split_forward_backward') + pprint.pprint(computation_trc) + print('============================================ END: computation_trc split_forward_backward') + primal_trace = computation_trc primal_trace = sort_data_parallel_syncs(primal_trace) + print('============================================ START: primal_trace sort_data_parallel_syncs') + pprint.pprint(primal_trace) + print('============================================ END: primal_trace sort_data_parallel_syncs') + if compile_stats is not None: compile_stats.last_traces.append(primal_trace) @@ -134,6 +143,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # not any other container type. So we need to flatten the outputs of # the forward trace and inputs of the backward trace. fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) + print('============================================ START: primal_trace forward_and_backward_from_trace') + pprint.pprint(fw_trace) + print('============================================ END: primal_trace forward_and_backward_from_trace') fw_traces = [fw_trace] bw_traces = [bw_trace] @@ -165,9 +177,15 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # Now we can run the optimization passes on the forward trace # TODO Restore request for no rematerialization + + import pprint + print('============================================ START: before forward_trc transform_for_execution') + pprint.pprint(fw_trace) + print('============================================ END: after forward_trc transform_for_execution') fw_extrace = transform_for_execution( fw_trace, executors_list=compile_data.executors_list, + label='forward_trc' ) fw_traces.append(fw_extrace) @@ -205,6 +223,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat bw_extrace = transform_for_execution( bw_trace, executors_list=compile_data.executors_list, + label='backward_trc' ) bw_traces.append(bw_extrace) diff --git a/thunder/visualizer/__init__.py b/thunder/visualizer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/thunder/visualizer/graphviz.py b/thunder/visualizer/graphviz.py new file mode 100644 index 0000000000..8b3ebd78a9 --- /dev/null +++ b/thunder/visualizer/graphviz.py @@ -0,0 +1,37 @@ +import graphviz +from thunder.executors.data_dependent_partition import Node, Graph + +def to_graphviz_dag(g: Graph) -> graphviz.Digraph: + dot = graphviz.Digraph() + visit_stack = list(g.roots) + # Add root nodes + r: Node + for r in g.roots: + dot.node(f'{r.ID}({r.group_bsyms[0].sym.name})') + + visited = set() + cur: Node + while visit_stack: + cur = visit_stack.pop(0) + if cur in visited: + continue + + cur_node_str = f'{cur.ID}({cur.group_bsyms[0].sym.name})' + dot.node(cur_node_str) + + # Connect with parent + for p in cur.parents: + id = p.ID + op = p.group_bsyms[0].sym.name + parent_str = f'{id}({op})' + dot.edge(parent_str, cur_node_str) + + visited.add(cur) + visit_stack.extend(cur.children) + return dot + + +def create_graphviz_pdf(g: Graph, name='graph'): + dot = to_graphviz_dag(g) + dot.render(name, view=False) + From 4ec2fa06a60f93202c1c52f741776d7adf69551e Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Thu, 4 Jul 2024 18:10:04 +0300 Subject: [PATCH 002/171] Single trace region placement impl --- thunder/backend_optimizer/optimizer.py | 307 ++++++++++++++++++++++--- thunder/executors/passes.py | 62 ++--- 2 files changed, 291 insertions(+), 78 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 8f8c3f87a7..43f0a00935 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,7 +1,17 @@ -from typing import Hashable +from typing import Any, Hashable +from thunder.core.baseutils import BoundSymbolInterface +from thunder.core.utils import check, safe_map_flat +from thunder.core.proxies import Proxy, variableify, Variable +from thunder.core.symbol import BoundSymbol from thunder.executors.data_dependent_partition import Graph, Node -from thunder.core.trace import TraceCtx +from thunder.core.trace import set_tracectx, reset_tracectx, from_trace, get_tracectx, TraceProvenance, TraceCtx from thunder.extend import Executor, FusionExecutor, OperatorExecutor +import thunder.core.transforms as transforms +from collections.abc import Callable, Sequence +from enum import Enum +from itertools import chain +import time +import pprint class OptimizerNode(): def __init__(self, node: Node): @@ -12,25 +22,22 @@ def add_candidate(self, ex: Executor, benchmarck: float): self.candidate_executors[ex] = benchmarck class BackendOptimizer(): - def __init__(self, trace: TraceCtx, executors: list[Executor]) -> None: + def __init__(self, trace: TraceCtx, executors: Sequence[Executor]) -> None: self.trace = trace self.computation_graph = Graph(trace) self.executors = executors - self.memo = {} self.default_cost = {} self.hash_separator = '#' self.dummy_cost = 1 self.optimizer_nodes = [] + self.placement_options: list[list[Executor]] = [] def __repr__(self) -> str: ret = self.computation_graph.__repr__() ret += "\n" n: OptimizerNode for n in self.optimizer_nodes: - ret += str(n.node.ID) + ' ####################################' - ret += "\n" - ret += n.candidate_executors.__repr__() - ret += "\n" + ret += f'NODE: {str(n.node.ID)} - {str(n.node.group_bsyms[0].sym.name)} ####################################\n {n.candidate_executors.__repr__()}\n' return ret def write(self, file_name): @@ -48,33 +55,267 @@ def compute_default_costs_subgraphs(self, nodes: list[Node]): hash = self.subgraph_hash(nodes) self.default_cost[hash] = self.dummy_cost - def build_search_space(self): - visited = set() - def dfs(node: Node): - visited.add(node.ID) - childs = node.children - node_bsym = node.group_bsyms[0] + def backed_placer(self): + pass + + def build_placement_options(self): + class ExecutorType(Enum): + OPERATOR = 1 + FUSER = 1 + + class SearchNode: + def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: + self.symbol = symbol + self.idx = idx + pass - optimizer_node: OptimizerNode = OptimizerNode(node) + # We assign an internal id to each symbol based on its idx inside the bound_symbols list + def search(node: SearchNode, configuration): + def continue_search(): + if node.idx+1 < max_len: + new_idx: int = node.idx + 1 + new_symbol: BoundSymbolInterface = bound_symbols[new_idx] + search(SearchNode(new_symbol, new_idx), configuration) + else: + all_configurations.append(list(configuration)) - print(f'Node id: {node.ID}, symbol: {node_bsym.sym.name}') + def update_dict(idx: int, type: ExecutorType, ex: Executor): + if idx not in res: + res[idx] = {} + res[node.idx][type] = [ex] + else: + if type not in res[idx]: + res[node.idx][type] = [ex] + else: + res[node.idx][type].append(ex) ex: Executor + has_backend = False for ex in self.executors: - print(f'analyzing executor {ex._name}') - if (isinstance(ex, OperatorExecutor) and ex.can_execute(node_bsym)): - print(f'{ex._name} can execute symbol {node_bsym.sym.name}') - optimizer_node.add_candidate(ex, 1.0) - if (isinstance(ex, FusionExecutor) and ex.can_fuse(node_bsym)): - print(f'{ex._name} can fuse symbol {node_bsym.sym.name}') - optimizer_node.add_candidate(ex, 1.0) - - self.optimizer_nodes.append(optimizer_node) - - if childs: - for c in childs: - if c.ID not in visited: - dfs(c) - - for root in self.computation_graph.roots: - dfs(root) + if not isinstance(node.symbol, BoundSymbol): + raise AssertionError("Receive a symbol which is not a BoundSymbol") + if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): + # print(f'{node.idx}-{ex._name} can execute symbol {node.symbol.sym.name}') + update_dict(node.idx, ExecutorType.OPERATOR, ex) + has_backend = True + configuration.append(ex) + continue_search() + configuration.pop(-1) + else: + pass + # print(f'{node.idx}-{ex._name} can NOT execute symbol {node.symbol.sym.name}') + if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): + # print(f'{node.idx}-{ex._name} can fuse symbol {node.symbol.sym.name}') + update_dict(node.idx, ExecutorType.FUSER, ex) + has_backend = True + configuration.append(ex) + continue_search() + configuration.pop(-1) + else: + pass + # print(f'{node.idx}-{ex._name} can NOT fuse symbol {node.symbol.sym.name}') + + if not has_backend: + configuration.append(empty_executor) + continue_search() + configuration.pop(-1) + + res: dict[int, dict[ExecutorType, list[Executor]]] = {} + bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols + bound_symbols_name = [s.sym.name for s in bound_symbols] + # bound_symbols_id = [s.sym.id for s in bound_symbols] + max_len = len(bound_symbols) + + all_configurations: list[list[Executor]] = [] + # Is the name reserved? + empty_executor = Executor(name='empty') + + print(f'input trace bound symbols name: {bound_symbols_name}') + # print(f'input trace bound symbols id: {bound_symbols_id}') + + if len(bound_symbols) > 0: + search(SearchNode(bound_symbols[0], 0), []) + self.placement_options = all_configurations + print(len(all_configurations)) + print('END OF SEDARCH') + for config in all_configurations: + c_str = [str(c.name) for c in config] + c_str = " ".join(c_str) + print(c_str) + + def place_optimizers(self, executor_list: list[Executor]) -> TraceCtx: + start_time_ns = time.time_ns() + + swapmap: dict[Variable, Proxy] = {} + + def update_swapmap(o: Any, no: Any) -> None: + if isinstance(o, Proxy): + check( + isinstance(no, Proxy), + lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}", + ) + + vo = variableify(o) + vno = variableify(no) + if vo == vno: + return + swapmap[vno] = o + + def preserve_bsym(bsym: BoundSymbol) -> Any: + trace = get_tracectx() + trace.scopes[-1].append(bsym) + for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): + trace.names.add(p.name) + return bsym.output + + def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: + if bsym.sym.python_impl is not None: + return None + + # We have mapped this at previous stages + if ex.name == 'empty': + return None + # The call above represent: + # if bsym.sym.executor is not None: + # return None + + execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) + out: Any + # TODO: What is this? + if execution_transform is not None: + out = execution_transform(*bsym.args, **bsym.kwargs) + elif isinstance(ex, OperatorExecutor): + # Calls the operator executor's operation + op = ex.implmap[bsym.sym.id].symbol + out = op(*bsym.args, **bsym.kwargs) + elif isinstance(ex, FusionExecutor): + # Preserves the symbol as is (it will be handled in the fusion pass) + out = preserve_bsym(bsym) + else: + raise AssertionError("Unknown executor") + + safe_map_flat(update_swapmap, bsym.output, out) + + return True + + def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: + return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE + + def visitor_transform(trace_from: TraceCtx, executors: list[Executor]): + trc: TraceCtx = from_trace(trace_from) + + try: + tracectx_tok = set_tracectx(trc) + + for bsym, ex in zip(trace_from.bound_symbols, executors): + try: + # Creates a temporary scope to support copying the original bsym BEFORE + # the operations performed by visit(), even though this doesn't know whether to + # copy the original bsym until after visit() completes + old_scope = trc.scopes + scope = [] + trc.scopes = [scope] + + # This can be simpler? We currently trigger all the flow for the substitution + visit_type = visit(bsym, ex) + + if visit_type is transforms.VISIT_TYPE.INSERT_AFTER: + trc.bound_symbols.append(bsym) + + if visit_type is not transforms.VISIT_TYPE.NO_OP: + trc.bound_symbols.extend(scope) + else: + trc.bound_symbols.append(bsym) + + if visit_type is transforms.VISIT_TYPE.INSERT_BEFORE: + trc.bound_symbols.append(bsym) + + finally: + # Restores the trc's scope + trc.scopes = old_scope + + return trc + + finally: + reset_tracectx(tracectx_tok) + + extrace = visitor_transform(self.trace, executor_list) + + # Restores original variables + bound_symbols: list[BoundSymbol] = [] + for bsym in extrace.bound_symbols: + nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) + bound_symbols.append(nbsym) + + extrace.bound_symbols = bound_symbols + + end_time_ns = time.time_ns() + elapsed_time_ns = end_time_ns - start_time_ns + elapsed_time_millis = elapsed_time_ns // 1000000 + extrace.set_provenance( + TraceProvenance(f"Transform for operator executor execution (took {elapsed_time_millis} milliseconds)") + ) + + print('============================================ trace before fusion pass') + pprint.pprint(extrace) + + # We have to temporary clear the subsymbols of already claimed symbols by not fusion ops, otherwise fusion ops will check recursively subsymbols and clear all the current placements + cached_subsymbols: list[Sequence[BoundSymbolInterface]] = [list(symbol.subsymbols) for symbol in extrace.bound_symbols] + subsymbols_idx_to_restore: list[int] = [] + unique_fusion_executors = set() + for idx, ex in enumerate(executor_list): + if isinstance(ex, FusionExecutor): + unique_fusion_executors.add(ex) + else: + subsymbols_idx_to_restore.append(idx) + extrace.bound_symbols[idx].subsymbols = () + + for ex in unique_fusion_executors: + extrace = ex.fusion_pass(extrace) + + # Restore the subsymbols + for idx in subsymbols_idx_to_restore: + extrace.bound_symbols[idx].subsymbols = cached_subsymbols[idx] + + return extrace + + def build_search_space(self): + self.build_placement_options() + for option in self.placement_options: + trace = self.place_optimizers(option) + option_str = [str(ex.name) for ex in option] + option_str = ' '.join(option_str) + print(f'============================================ config_trace, optimizers: {option_str}') + pprint.pprint(trace) + + # visited = set() + # def dfs(node: Node): + # visited.add(node.ID) + # childs = node.children + # node_symbols = [str(s.sym.id) for s in node.group_bsyms] + # node_symbols = " ".join(node_symbols) + # node_bsym = node.group_bsyms[0] + + # optimizer_node: OptimizerNode = OptimizerNode(node) + + # print(f'-> Node id: {node.ID}, symbols: {node_symbols}') + + # ex: Executor + # for ex in self.executors: + # print(f'analyzing executor {ex._name}') + # if (isinstance(ex, OperatorExecutor) and ex.can_execute(node_bsym)): + # print(f'{ex._name} can execute symbol {node_bsym.sym.name}') + # optimizer_node.add_candidate(ex, 1.0) + # if (isinstance(ex, FusionExecutor) and ex.can_fuse(node_bsym)): + # print(f'{ex._name} can fuse symbol {node_bsym.sym.name}') + # optimizer_node.add_candidate(ex, 1.0) + + # self.optimizer_nodes.append(optimizer_node) + + # if childs: + # for c in childs: + # if c.ID not in visited: + # dfs(c) + + # for root in self.computation_graph.roots: + # dfs(root) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 4ce1c5b31b..30cf63227c 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -59,29 +59,14 @@ def preserve_bsym(bsym: BoundSymbol) -> Any: # If the executor has an execution transform, it's called and True is returned # If no executor can execute the BoundSymbol, False is returned def visit_helper_(bsym: BoundSymbol) -> None | bool: - import pprint if bsym.sym.python_impl is not None: return None - subsymbols = [str(s.sym.id) for s in bsym.subsymbols] - subsymbols = " ".join(subsymbols) - print(f'\n -> analyzing bsym: {bsym.sym.id} - subsymbols: {subsymbols}') - executors_names = [str(n._name) for n in executors_list] - executors_names = ' '.join(executors_names) - print(f'available executors: {executors_names}') - print('what?') - ex: Executor for ex in executors_list: # TODO Consider allowing operator executors to claim portions of operations # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor? - print(f'testing executor: {ex._name}') - if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)): - print(ex.name, 'CAN execute', bsym.sym.id) - else: - can_fuse = isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) - print(ex.name, 'can NOT execute', bsym.sym.id, 'but can fuse? ', can_fuse) if (isinstance(ex, OperatorExecutor) and ex.can_execute(bsym)) or ( isinstance(ex, FusionExecutor) and ex.can_fuse(bsym) @@ -106,18 +91,11 @@ def visit_helper_(bsym: BoundSymbol) -> None | bool: safe_map_flat(update_swapmap, bsym.output, out) - # Here is the point were we return the first available one - # print('swap_map -> sym swapped, symbol:', bsym.sym.id) - # pprint.pprint((swapmap)) return True if bsym.sym.executor is not None: - # print('swap_map -> sym.executor is None', bsym.sym.id) - # pprint.pprint((swapmap)) return None - # print('swap_map -> nothing happened', bsym.sym.id) - # pprint.pprint((swapmap)) return False def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: @@ -160,9 +138,11 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], import pprint from thunder.executors.data_dependent_partition import Graph + what_to_log = "forward_trc" + print(f'============================================ START: LABEL {label}') print(f'============================================ START: executor_list: {executors_list}') - print(f'============================================ START: always_executor_list: {get_all_executors()}') + print(f'============================================ START: always_executor_list: {get_always_executors()}') start_time_ns = time.time_ns() @@ -174,40 +154,32 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], trace = dce(trace) - if (label == 'forward_trc'): + if (label == what_to_log): print('============================================ START: before _transform_for_operator_executor_execution') pprint.pprint(trace) + print('============================================ GRAPH: before _transform_for_operator_executor_execution') + print(Graph(trace)) print('============================================ END: before _transform_for_operator_executor_execution') + + if label == what_to_log: + print('============================================ start: BACKEND_OPTIMIZER') + backend_optimizer = BackendOptimizer(trace, executors_list) + backend_optimizer.build_search_space() + backend_optimizer.write("backend_optimizer.log") + print('============================================ end: BACKEND_OPTIMIZER') + # # Step 1 Performs execution transforms # - # print('#### A') - # custom_exs = [ex for ex in executors_list if ex._name == 'python' or ex._name == 'torch'] - # extrace = _transform_for_operator_executor_execution(trace, custom_exs) - # if (label == 'forward_trc'): - # print('============================================ START: after _transform_for_operator_executor_execution') - # pprint.pprint(extrace) - # print('============================================ GRAPH: _transform_for_operator_executor_execution') - # g = Graph(trace) - # create_graphviz_pdf(g, label) - # print(g) - # print('============================================ END: after _transform_for_operator_executor_execution') - # extrace = dce(extrace) - - print('#### B') extrace = _transform_for_operator_executor_execution(trace, executors_list) - if (label == 'forward_trc'): + if (label == what_to_log): print('============================================ START: after _transform_for_operator_executor_execution') pprint.pprint(extrace) print('============================================ GRAPH: _transform_for_operator_executor_execution') g = Graph(trace) create_graphviz_pdf(g, label) print(g) - print('============================================ BACKEND_OPTIMIZER') - backend_optimizer = BackendOptimizer(extrace, executors_list) - backend_optimizer.build_search_space() - backend_optimizer.write("backend_optimizer.log") print('============================================ END: after _transform_for_operator_executor_execution') extrace = dce(extrace) @@ -219,7 +191,7 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], if isinstance(ex, FusionExecutor): extrace = ex.fusion_pass(extrace) - if (label == 'forward_trc'): + if (label == what_to_log): print('============================================ START: after fusion_pass') pprint.pprint(extrace) print('============================================ GRAPH: fusion_pass') @@ -236,7 +208,7 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], # NOTE This occurs if a fusion executor declines to execute a symbol after running its fusion pass extrace = _transform_for_operator_executor_execution(extrace, get_always_executors()) - if (label == 'forward_trc'): + if (label == what_to_log): print('============================================ START: after _transform_for_operator_executor_execution (always)') pprint.pprint(extrace) print('============================================ GRAPH: fusion_pass') From 8d084a85e05ee4adb08a2285d62ec9dc0e500dd2 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Tue, 9 Jul 2024 10:19:23 +0300 Subject: [PATCH 003/171] Serial exaustive search --- thunder/backend_optimizer/optimizer.py | 363 ++++++++++++++---------- thunder/common.py | 3 +- thunder/core/codeutils.py | 16 +- thunder/core/transforms.py | 67 ++++- thunder/executors/passes.py | 81 +++--- thunder/executors/torch_autograd.py | 41 ++- thunder/visualizer/graphviz.py | 5 +- thunder/visualizer/visualizer_helper.py | 48 ++++ 8 files changed, 391 insertions(+), 233 deletions(-) create mode 100644 thunder/visualizer/visualizer_helper.py diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 43f0a00935..771ff29059 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,44 +1,51 @@ from typing import Any, Hashable +import torch +import thunder from thunder.core.baseutils import BoundSymbolInterface from thunder.core.utils import check, safe_map_flat -from thunder.core.proxies import Proxy, variableify, Variable -from thunder.core.symbol import BoundSymbol +from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable +from thunder.core.symbol import BoundSymbol, Symbol from thunder.executors.data_dependent_partition import Graph, Node -from thunder.core.trace import set_tracectx, reset_tracectx, from_trace, get_tracectx, TraceProvenance, TraceCtx -from thunder.extend import Executor, FusionExecutor, OperatorExecutor +from thunder.core.trace import set_tracectx, reset_tracectx, get_tracectx, TraceCtx +from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors import thunder.core.transforms as transforms +from thunder.visualizer.visualizer_helper import Visualizer from collections.abc import Callable, Sequence from enum import Enum from itertools import chain import time -import pprint +# import pprint class OptimizerNode(): def __init__(self, node: Node): self.node: Node = node self.candidate_executors: dict[Hashable, float] = {} - def add_candidate(self, ex: Executor, benchmarck: float): - self.candidate_executors[ex] = benchmarck + def add_candidate(self, ex: Executor, benchmark: float): + self.candidate_executors[ex] = benchmark class BackendOptimizer(): - def __init__(self, trace: TraceCtx, executors: Sequence[Executor]) -> None: - self.trace = trace - self.computation_graph = Graph(trace) - self.executors = executors - self.default_cost = {} - self.hash_separator = '#' - self.dummy_cost = 1 - self.optimizer_nodes = [] + def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=True, log_file_name='autotune_traces_computation_time.log', visualizer: Visualizer | None = None) -> None: + self.trace: TraceCtx = trace + self.optimal_trace: TraceCtx = trace + self.computation_graph: Graph = Graph(trace) + self.executors: Sequence[Executor] = executors + self.empty_executor_hashable_placeholder: str = 'empty' self.placement_options: list[list[Executor]] = [] - + self.optimzide_traces: list[TraceCtx] = [] + self.always_executors: tuple[Executor, ...] = get_always_executors() + self.produce_log: bool = produce_log + self.log_file_name: str = log_file_name + self.log_str: str = "" + self.visualizer: Visualizer | None = visualizer + + print('INIT TRACE') + import pprint + pprint.pprint(self.trace) + + # TODO (matteochen): fix this def __repr__(self) -> str: - ret = self.computation_graph.__repr__() - ret += "\n" - n: OptimizerNode - for n in self.optimizer_nodes: - ret += f'NODE: {str(n.node.ID)} - {str(n.node.group_bsyms[0].sym.name)} ####################################\n {n.candidate_executors.__repr__()}\n' - return ret + return '' def write(self, file_name): with open(file_name, 'w') as file: @@ -46,17 +53,16 @@ def write(self, file_name): file.write(s) file.close() - def subgraph_hash(self, nodes: list[Node]): - ids = [str(n.ID) for n in nodes] - return self.hash_separator.join(ids) - - # TODO: to implement - def compute_default_costs_subgraphs(self, nodes: list[Node]): - hash = self.subgraph_hash(nodes) - self.default_cost[hash] = self.dummy_cost + def compute_time_cost(self, fn: Callable, iters: int, *args) -> tuple[float, Any]: + total_time = 0 + out = None + for _ in range(iters): + time_s = time.time_ns() + out = fn(*args) + time_e = time.time_ns() + total_time += (time_e - time_s) - def backed_placer(self): - pass + return total_time / iters, out def build_placement_options(self): class ExecutorType(Enum): @@ -67,7 +73,6 @@ class SearchNode: def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: self.symbol = symbol self.idx = idx - pass # We assign an internal id to each symbol based on its idx inside the bound_symbols list def search(node: SearchNode, configuration): @@ -77,9 +82,10 @@ def continue_search(): new_symbol: BoundSymbolInterface = bound_symbols[new_idx] search(SearchNode(new_symbol, new_idx), configuration) else: + print(f'reached end of search for this tree branch {configuration}') all_configurations.append(list(configuration)) - def update_dict(idx: int, type: ExecutorType, ex: Executor): + def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): if idx not in res: res[idx] = {} res[node.idx][type] = [ex] @@ -96,24 +102,18 @@ def update_dict(idx: int, type: ExecutorType, ex: Executor): raise AssertionError("Receive a symbol which is not a BoundSymbol") if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): # print(f'{node.idx}-{ex._name} can execute symbol {node.symbol.sym.name}') - update_dict(node.idx, ExecutorType.OPERATOR, ex) + safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) has_backend = True configuration.append(ex) continue_search() configuration.pop(-1) - else: - pass - # print(f'{node.idx}-{ex._name} can NOT execute symbol {node.symbol.sym.name}') if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): # print(f'{node.idx}-{ex._name} can fuse symbol {node.symbol.sym.name}') - update_dict(node.idx, ExecutorType.FUSER, ex) + safe_update_dict(node.idx, ExecutorType.FUSER, ex) has_backend = True configuration.append(ex) continue_search() configuration.pop(-1) - else: - pass - # print(f'{node.idx}-{ex._name} can NOT fuse symbol {node.symbol.sym.name}') if not has_backend: configuration.append(empty_executor) @@ -123,28 +123,27 @@ def update_dict(idx: int, type: ExecutorType, ex: Executor): res: dict[int, dict[ExecutorType, list[Executor]]] = {} bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols bound_symbols_name = [s.sym.name for s in bound_symbols] - # bound_symbols_id = [s.sym.id for s in bound_symbols] max_len = len(bound_symbols) all_configurations: list[list[Executor]] = [] # Is the name reserved? - empty_executor = Executor(name='empty') + empty_executor = Executor(name=self.empty_executor_hashable_placeholder) - print(f'input trace bound symbols name: {bound_symbols_name}') - # print(f'input trace bound symbols id: {bound_symbols_id}') + print(f'input trace bound symbols name len {len(bound_symbols_name)}: {bound_symbols_name}') if len(bound_symbols) > 0: search(SearchNode(bound_symbols[0], 0), []) + print('end of search') self.placement_options = all_configurations - print(len(all_configurations)) - print('END OF SEDARCH') - for config in all_configurations: - c_str = [str(c.name) for c in config] - c_str = " ".join(c_str) - print(c_str) + print('config len', len(all_configurations)) + # for config in all_configurations: + # c_str = [str(c.name) for c in config] + # c_str = " ".join(c_str) + # print(c_str) def place_optimizers(self, executor_list: list[Executor]) -> TraceCtx: - start_time_ns = time.time_ns() + + from thunder.executors.passes import _transform_for_operator_executor_execution swapmap: dict[Variable, Proxy] = {} @@ -162,7 +161,9 @@ def update_swapmap(o: Any, no: Any) -> None: swapmap[vno] = o def preserve_bsym(bsym: BoundSymbol) -> Any: - trace = get_tracectx() + trace: TraceCtx | None = get_tracectx() + if trace is None: + raise AssertionError('None trace context') trace.scopes[-1].append(bsym) for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): trace.names.add(p.name) @@ -173,7 +174,7 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: return None # We have mapped this at previous stages - if ex.name == 'empty': + if ex.name == self.empty_executor_hashable_placeholder: return None # The call above represent: # if bsym.sym.executor is not None: @@ -186,7 +187,9 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: out = execution_transform(*bsym.args, **bsym.kwargs) elif isinstance(ex, OperatorExecutor): # Calls the operator executor's operation - op = ex.implmap[bsym.sym.id].symbol + op: Symbol | None = ex.implmap[bsym.sym.id].symbol + if op is None: + raise AssertionError('op is None') out = op(*bsym.args, **bsym.kwargs) elif isinstance(ex, FusionExecutor): # Preserves the symbol as is (it will be handled in the fusion pass) @@ -201,45 +204,10 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE - def visitor_transform(trace_from: TraceCtx, executors: list[Executor]): - trc: TraceCtx = from_trace(trace_from) - - try: - tracectx_tok = set_tracectx(trc) - - for bsym, ex in zip(trace_from.bound_symbols, executors): - try: - # Creates a temporary scope to support copying the original bsym BEFORE - # the operations performed by visit(), even though this doesn't know whether to - # copy the original bsym until after visit() completes - old_scope = trc.scopes - scope = [] - trc.scopes = [scope] - - # This can be simpler? We currently trigger all the flow for the substitution - visit_type = visit(bsym, ex) - - if visit_type is transforms.VISIT_TYPE.INSERT_AFTER: - trc.bound_symbols.append(bsym) - - if visit_type is not transforms.VISIT_TYPE.NO_OP: - trc.bound_symbols.extend(scope) - else: - trc.bound_symbols.append(bsym) - - if visit_type is transforms.VISIT_TYPE.INSERT_BEFORE: - trc.bound_symbols.append(bsym) - - finally: - # Restores the trc's scope - trc.scopes = old_scope - - return trc + # for s, o in zip(self.trace.bound_symbols, executor_list): + # print(f'{s} -> {o}') - finally: - reset_tracectx(tracectx_tok) - - extrace = visitor_transform(self.trace, executor_list) + extrace = transforms.visitor_transform_paired(self.trace, visit, zip(self.trace.bound_symbols, executor_list)) # Restores original variables bound_symbols: list[BoundSymbol] = [] @@ -249,73 +217,176 @@ def visitor_transform(trace_from: TraceCtx, executors: list[Executor]): extrace.bound_symbols = bound_symbols - end_time_ns = time.time_ns() - elapsed_time_ns = end_time_ns - start_time_ns - elapsed_time_millis = elapsed_time_ns // 1000000 - extrace.set_provenance( - TraceProvenance(f"Transform for operator executor execution (took {elapsed_time_millis} milliseconds)") - ) - - print('============================================ trace before fusion pass') - pprint.pprint(extrace) + # print('============================================ trace before fusion pass') + # pprint.pprint(extrace) # We have to temporary clear the subsymbols of already claimed symbols by not fusion ops, otherwise fusion ops will check recursively subsymbols and clear all the current placements - cached_subsymbols: list[Sequence[BoundSymbolInterface]] = [list(symbol.subsymbols) for symbol in extrace.bound_symbols] - subsymbols_idx_to_restore: list[int] = [] + cached_subsymbols: dict[str, Sequence[BoundSymbolInterface]] = {} unique_fusion_executors = set() - for idx, ex in enumerate(executor_list): + for ex, bsym in zip(executor_list, extrace.bound_symbols): + bsym_hash: str = hex(id(bsym)) + cached_subsymbols[bsym_hash] = list(bsym.subsymbols) if isinstance(ex, FusionExecutor): unique_fusion_executors.add(ex) else: - subsymbols_idx_to_restore.append(idx) - extrace.bound_symbols[idx].subsymbols = () + bsym.subsymbols = () + # Perform fusion pass for ex in unique_fusion_executors: extrace = ex.fusion_pass(extrace) # Restore the subsymbols - for idx in subsymbols_idx_to_restore: - extrace.bound_symbols[idx].subsymbols = cached_subsymbols[idx] + for bsym in extrace.bound_symbols: + hash = hex(id(bsym)) + if hash in cached_subsymbols: + bsym.subsymbols = cached_subsymbols[hash] + + # print('============================================ trace after fusion pass') + # pprint.pprint(extrace) + + # Apply always executors + extrace = _transform_for_operator_executor_execution(extrace, self.always_executors) + + # print('============================================ trace after always executors pass') + # pprint.pprint(extrace) return extrace def build_search_space(self): + import thunder.core.codeutils as cutils + self.build_placement_options() + for option in self.placement_options: - trace = self.place_optimizers(option) option_str = [str(ex.name) for ex in option] - option_str = ' '.join(option_str) - print(f'============================================ config_trace, optimizers: {option_str}') - pprint.pprint(trace) - - # visited = set() - # def dfs(node: Node): - # visited.add(node.ID) - # childs = node.children - # node_symbols = [str(s.sym.id) for s in node.group_bsyms] - # node_symbols = " ".join(node_symbols) - # node_bsym = node.group_bsyms[0] - - # optimizer_node: OptimizerNode = OptimizerNode(node) - - # print(f'-> Node id: {node.ID}, symbols: {node_symbols}') - - # ex: Executor - # for ex in self.executors: - # print(f'analyzing executor {ex._name}') - # if (isinstance(ex, OperatorExecutor) and ex.can_execute(node_bsym)): - # print(f'{ex._name} can execute symbol {node_bsym.sym.name}') - # optimizer_node.add_candidate(ex, 1.0) - # if (isinstance(ex, FusionExecutor) and ex.can_fuse(node_bsym)): - # print(f'{ex._name} can fuse symbol {node_bsym.sym.name}') - # optimizer_node.add_candidate(ex, 1.0) - - # self.optimizer_nodes.append(optimizer_node) - - # if childs: - # for c in childs: - # if c.ID not in visited: - # dfs(c) - - # for root in self.computation_graph.roots: - # dfs(root) + option_str = '-'.join(option_str) + # print(f'============================================ optimizers len {len(option)}: {option_str}') + trace = self.place_optimizers(option) + + if self.visualizer is not None: + sig_name = cutils.get_siginfo_name(trace) + # TODO (matteochen): consider adding more infos for naming + self.visualizer.set_hidden_trace(f'hidden-{sig_name}-{option_str}', trace) + + self.optimzide_traces.append(trace) + + def get_optimal_trace(self) -> TraceCtx: + return self.optimal_trace + + def benchmark_trace(self, trace: TraceCtx) -> float: + + input_args = [] + + def print_input_args(args, level=0, show_content = False): + for e in args: + if isinstance(e, tuple) or isinstance(e, list): + print_input_args(e, level=level+1) + else: + print(f'level {level}', type(e)) + + def print_trace_execution_output(out: Any, show_content=False): + if isinstance(out, tuple): + for e in out: + print(f'{type(e)}') + else: + print(f'{type(out)}') + + def thunder_to_torch_float_dtype(byte: int) -> torch.dtype: + if (byte == 2): + return torch.float16 + elif (byte == 4): + return torch.float32 + else: + return torch.float64 + + def transform_input_tuple(t: tuple, level=0) -> tuple: + res = [] + for e in t: + if type(e) is tuple: + res.append(transform_input_tuple(e, level+1)) + else: + # print(f'level {level}', type(e)) + if isinstance(e, TensorProxy): + res.append(transform_tensor(e)) + else: + # TODO (matteochen): support more data types + raise AssertionError(f'Input arg type not recognized: {type(e)}') + return tuple(res) + + def transform_tensor(arg: TensorProxy) -> torch.Tensor: + dtype = arg.dtype + if dtype is not None and type(dtype) is thunder.dtypes.floating: + torch_dtype = thunder_to_torch_float_dtype(dtype.bytes) + # print(f'thunder type: {dtype} torch_dtype: {torch_dtype}') + else: + # TODO (matteochen): support other types + raise AssertionError(f"dtype {dtype} not supported yet") + + shape = arg.shape + device = arg.device + requires_grad = arg.requires_grad + # TODO (matteochen): Missing parallel and fsdp handling... + # TODO (matteochen): Missing support for meta types ... + tensor: torch.Tensor = torch.randn(*shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad) + # print(f'Adding tensor shape: {tensor.shape} dtype: {tensor.dtype} device: {tensor.device} requires_grad: {tensor.requires_grad}') + return tensor + + # Can we remove this check? + if isinstance(trace.args, list): + for arg in trace.args: + # print(f'current arg {arg}\ntype {type(arg)}') + if isinstance(arg, tuple): + # print('Processig tuple') + input_args.append(transform_input_tuple(arg)) + elif isinstance(arg, TensorProxy): + # print('Processig TensorProxy') + e = transform_tensor(arg) + input_args.append(e) + else: + raise AssertionError(f'Input arg type not recognized: {type(arg)}') + else: + raise AssertionError('Unexpexcted args type') + + # print('========================================= benchmark_trace: input_args') + # print_input_args(input_args, level=0) + + # TODO (matteochen): measure time + trace_tok = set_tracectx(trace) + + # Obtain the python executable string + executable_str = trace.python_callable() + t, _ = self.compute_time_cost(executable_str, 10, *input_args) + + reset_tracectx(trace_tok) + + # Note, currently the forward pass returns a tuple: + # ( + # dict, + # ... + # ) + # We have to access the dict['output'] in order to get the forward computation result + + if self.produce_log: + self.log_str += f'Time taken: {t / 1000000}ms\n' + self.log_str += trace.python() + self.log_str += '\n#############################################################################################################\n' + + # print('========================================= benchmark_trace out') + # print_trace_execution_output(out) + + return t + + def benchmark_traces(self): + min_run_time = float('inf') + optimal_trace: TraceCtx = self.trace # Assign initial value for unbound errors + for trace in self.optimzide_traces: + trace_time = self.benchmark_trace(trace) + if trace_time < min_run_time: + min_run_time = trace_time + optimal_trace = trace + + self.optimal_trace = optimal_trace + + with open(self.log_file_name, 'w') as file: + file.write(self.log_str) + file.close() diff --git a/thunder/common.py b/thunder/common.py index 94328dfa4e..23770a0975 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -613,7 +613,6 @@ def wait_for_future(f: FutureTensorProxy) -> TensorProxy: def transform_for_execution( trace: TraceCtx, executors_list: Sequence[Executor], - label = "default", *, only_execute_prims=False, use_rematerialization=True, @@ -630,7 +629,7 @@ def transform_for_execution( # cse_trace = cse(dce_trace) # traces.append(cse_trace) - extrace = executors.passes.transform_for_execution(dce_trace, executors_list, label) + extrace = executors.passes.transform_for_execution(dce_trace, executors_list) traces.append(extrace) diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 2c74cc8053..59d34a8ed8 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -5,15 +5,12 @@ from collections.abc import Mapping, Sequence, Iterable import inspect from inspect import Parameter -import string import functools from functools import partial import dis import linecache import dataclasses -import torch - import thunder.core.baseutils as baseutils from thunder.core.baseutils import ProxyInterface, check import thunder.core.dtypes as dtypes @@ -456,3 +453,16 @@ class NamedBindings: si.defaultdict = default_dict si.unwrapped_fn = unwrapped return si + +def get_siginfo_name(trace) -> str: + try: + name = "" + if trace.fn is not None: + siginfo: SigInfo = get_siginfo(trace.fn, trace.args, trace.kwargs) + name = siginfo.name + else: + name = 'unknown' + + return name + except Exception as e: + raise AssertionError(f'Is input trace an instance of TraceCtx?\n{e}') diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 5b3491a5af..cc24140249 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -334,6 +334,65 @@ class VISIT_TYPE(Enum): NO_OP = auto() +# Creates a new trace from "trace_from" by calling "visit" on its bound symbols ("bsyms") paired with an assigned executor. +# visit(bsym: BoundSymbolInterface, ex: Executor) -> VISIT_TYPE should call operations +# as if executing a program, and those operations will be recorded into the +# new trace. +# If visit() returns INSERT_AFTER for a bsym then that bsym will be copied +# to the new trace before visit() is called. This is useful when augmenting the bound +# symbols in an existing trace. +# If visit() returns INSERT_BEFORE for a bsym then that bsym will be copied to the new trace +# after visit() is called. This is also useful when augmenting the bound symbols in an existing +# trace. +# If visit() returns REPLACE for a bsym then that bsym will not be copied to the new trace. +# TODO Suggest a mechanism to preserve the original bound symbol with operations +# recorded both before and after it. This could be done by passing the (sub)scope to visit() for +# direct modification, acquiring the trace's current scope through the trace ctx and modifying it +# directly (this can be done today), or adding a record() function that is a sugar for the previous +# approach. Perhaps both passing the scope directly to visit() and adding record() would be helpful. +# TODO(crcrpar): Think about providing a guide how to let thunder "claim" if this is called after +# `thunder.executors.transform_for_execution`. +def visitor_transform_paired(trace_from: Trace, visit: Callable, zipped: zip, *, provenance: None | str = None): + trc: Trace = from_trace(trace_from) + + try: + tracectx_tok = set_tracectx(trc) + + for bsym, ex in zipped: + try: + # Creates a temporary scope to support copying the original bsym BEFORE + # the operations performed by visit(), even though this doesn't know whether to + # copy the original bsym until after visit() completes + old_scope = trc.scopes + scope = [] + trc.scopes = [scope] + + # This can be simpler? We currently trigger all the flow for the substitution + visit_type = visit(bsym, ex) + + if visit_type is VISIT_TYPE.INSERT_AFTER: + trc.bound_symbols.append(bsym) + + if visit_type is not VISIT_TYPE.NO_OP: + trc.bound_symbols.extend(scope) + else: + trc.bound_symbols.append(bsym) + + if visit_type is VISIT_TYPE.INSERT_BEFORE: + trc.bound_symbols.append(bsym) + + finally: + # Restores the trc's scope + trc.scopes = old_scope + + if provenance is not None: + trc.set_provenance(TraceProvenance(provenance)) + + return trc + + finally: + reset_tracectx(tracectx_tok) + # Creates a new trace from "trace_from" by calling "visit" on its bound symbols ("bsyms"). # visit(bsym: BoundSymbolInterface) -> VISIT_TYPE should call operations # as if executing a program, and those operations will be recorded into the @@ -3609,16 +3668,8 @@ def forward_and_backward_from_trace(trace: Trace, torch_autograd=False) -> Forwa output_spec = None - import pprint - def augmented_forward_fn(*args, **kwargs): result, env = augmented_forward_pass(*args, trace=trace, **kwargs) - print('============================================ START: augmented_forward_pass') - print('result') - pprint.pprint(result) - print('env') - pprint.pprint(env) - print('============================================ END: augmented_forward_pass') saved_for_backward = deconstruct_forward_env_for_backward(trace, env) if torch_autograd: nonlocal output_spec diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 30cf63227c..bda0d6a2dc 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -1,3 +1,4 @@ +from pprint import pprint from typing import Dict, Any, List, Tuple, Optional from collections.abc import Callable from collections.abc import Sequence @@ -7,14 +8,15 @@ from functools import partial import time -from thunder.core.trace import TraceCtx, from_trace, TraceProvenance, VariableInterface +from thunder.core.trace import TraceCtx, from_trace, TraceProvenance, VariableInterface, reset_tracectx, set_tracectx +from thunder.core.codeutils import SigInfo import thunder.core.dtypes as dtypes import thunder.core.utils as cutils from thunder.core.utils import ProxyDict, check, safe_map_flat from thunder.core.symbol import BoundSymbol from thunder.core.pytree import tree_flatten, tree_unflatten, tree_map import thunder.core.prims as prims -from thunder.core.proxies import Proxy, variableify, unvariableify, Variable, CollectionProxy +from thunder.core.proxies import Proxy, TensorProxy, variableify, unvariableify, Variable, CollectionProxy import thunder.core.transforms as transforms from thunder.core.transform_common import dce from thunder.core.trace import get_tracectx @@ -23,6 +25,7 @@ from thunder.extend import Executor, get_all_executors, get_always_executors, OperatorExecutor, FusionExecutor from thunder.backend_optimizer.optimizer import BackendOptimizer from thunder.visualizer.graphviz import create_graphviz_pdf +from thunder.visualizer.visualizer_helper import Visualizer comment_symbols = {prims.PrimIDs.COMMENT, prims.PrimIDs.UNPACK_TRIVIAL} @@ -132,17 +135,12 @@ def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: ) return extrace - -def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], label='default') -> TraceCtx: +# Autotuned transform_for_execution version +def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], visualizer: Visualizer | None = None) -> TraceCtx: import torch - import pprint - from thunder.executors.data_dependent_partition import Graph - - what_to_log = "forward_trc" - print(f'============================================ START: LABEL {label}') - print(f'============================================ START: executor_list: {executors_list}') - print(f'============================================ START: always_executor_list: {get_always_executors()}') + # Recover the function name + sig_name = cutils.get_siginfo_name(trace) start_time_ns = time.time_ns() @@ -154,33 +152,36 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], trace = dce(trace) - if (label == what_to_log): - print('============================================ START: before _transform_for_operator_executor_execution') - pprint.pprint(trace) - print('============================================ GRAPH: before _transform_for_operator_executor_execution') - print(Graph(trace)) - print('============================================ END: before _transform_for_operator_executor_execution') + backend_optimizer = BackendOptimizer(trace, executors_list, produce_log=True, log_file_name=f'autotune_transform_for_execution_{sig_name}.log', visualizer=visualizer) + backend_optimizer.build_search_space() + backend_optimizer.benchmark_traces() + extrace = backend_optimizer.get_optimal_trace() + end_time_ns = time.time_ns() + elapsed_time_ns = end_time_ns - start_time_ns + elapsed_time_millis = elapsed_time_ns // 1000000 - if label == what_to_log: - print('============================================ start: BACKEND_OPTIMIZER') - backend_optimizer = BackendOptimizer(trace, executors_list) - backend_optimizer.build_search_space() - backend_optimizer.write("backend_optimizer.log") - print('============================================ end: BACKEND_OPTIMIZER') + extrace.set_provenance(TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)")) + return extrace + + +def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: + import torch + + start_time_ns = time.time_ns() + + if torch.distributed.is_available(): + # Apply AllReduce bucketing if possible & needed + from thunder.distributed.transforms.ddp import apply_bucketing_to_grad_allreduce + + trace = apply_bucketing_to_grad_allreduce(trace) + + trace = dce(trace) # # Step 1 Performs execution transforms # extrace = _transform_for_operator_executor_execution(trace, executors_list) - if (label == what_to_log): - print('============================================ START: after _transform_for_operator_executor_execution') - pprint.pprint(extrace) - print('============================================ GRAPH: _transform_for_operator_executor_execution') - g = Graph(trace) - create_graphviz_pdf(g, label) - print(g) - print('============================================ END: after _transform_for_operator_executor_execution') extrace = dce(extrace) @@ -191,15 +192,6 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], if isinstance(ex, FusionExecutor): extrace = ex.fusion_pass(extrace) - if (label == what_to_log): - print('============================================ START: after fusion_pass') - pprint.pprint(extrace) - print('============================================ GRAPH: fusion_pass') - g = Graph(extrace) - create_graphviz_pdf(g, f'{label}_fusion') - print(g) - print('============================================ END: after fusion_pass') - # # Step 3 "Always" executors are given the opportunity to execute unclaimed symbols # @@ -208,15 +200,6 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], # NOTE This occurs if a fusion executor declines to execute a symbol after running its fusion pass extrace = _transform_for_operator_executor_execution(extrace, get_always_executors()) - if (label == what_to_log): - print('============================================ START: after _transform_for_operator_executor_execution (always)') - pprint.pprint(extrace) - print('============================================ GRAPH: fusion_pass') - g = Graph(extrace) - create_graphviz_pdf(g, f'{label}_final') - print(g) - print('============================================ END: after _transform_for_operator_executor_execution (always)') - end_time_ns = time.time_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index b543f06921..57bf6c58ea 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -105,13 +105,16 @@ def backward(ctx, *args): del grads return (None, None, None, None, None, *([None] * n_grads)) - +# TODO (matteochen): add control for using autotuner or not def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args): from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace from thunder.distributed.transforms import FSDPCommBucketing from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops - from thunder.executors.passes import del_last_used, transform_for_execution + from thunder.executors.passes import del_last_used, autotune_transform_for_execution + from thunder.visualizer.visualizer_helper import Visualizer + + visualizer = Visualizer(produce_hidden=False) utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used @@ -122,18 +125,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if not any(requires_grad_mask): raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") - import pprint - print('============================================ START: computation_trc split_forward_backward') - pprint.pprint(computation_trc) - print('============================================ END: computation_trc split_forward_backward') - primal_trace = computation_trc primal_trace = sort_data_parallel_syncs(primal_trace) - print('============================================ START: primal_trace sort_data_parallel_syncs') - pprint.pprint(primal_trace) - print('============================================ END: primal_trace sort_data_parallel_syncs') - if compile_stats is not None: compile_stats.last_traces.append(primal_trace) @@ -143,9 +137,6 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # not any other container type. So we need to flatten the outputs of # the forward trace and inputs of the backward trace. fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) - print('============================================ START: primal_trace forward_and_backward_from_trace') - pprint.pprint(fw_trace) - print('============================================ END: primal_trace forward_and_backward_from_trace') fw_traces = [fw_trace] bw_traces = [bw_trace] @@ -178,16 +169,14 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # Now we can run the optimization passes on the forward trace # TODO Restore request for no rematerialization - import pprint - print('============================================ START: before forward_trc transform_for_execution') - pprint.pprint(fw_trace) - print('============================================ END: after forward_trc transform_for_execution') - fw_extrace = transform_for_execution( + visualizer.set_fw_initial_trace(fw_trace) + fw_extrace = autotune_transform_for_execution( fw_trace, executors_list=compile_data.executors_list, - label='forward_trc' + visualizer=visualizer ) fw_traces.append(fw_extrace) + visualizer.set_fw_optimized_trace(fw_extrace) # Some of the optimization passes change proxies in the trace and # any change in the forward trace must be reflected in the backward @@ -220,12 +209,14 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization - bw_extrace = transform_for_execution( + visualizer.set_bw_initial_trace(bw_trace) + bw_extrace = autotune_transform_for_execution( bw_trace, executors_list=compile_data.executors_list, - label='backward_trc' + visualizer=visualizer ) bw_traces.append(bw_extrace) + visualizer.set_bw_optimized_trace(bw_extrace) fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) fw_traces.append(fw_extrace) @@ -295,4 +286,10 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # We only want the forward function to be called with `te.fp8_autocast` manager. bw_extrace._include_te_fp8_autocast = False + # Let's include the last traces also after all the passes + visualizer.set_fw_final_trace(fw_extrace) + visualizer.set_bw_final_trace(bw_extrace) + + visualizer.produce() + return fw_extrace, bw_extrace diff --git a/thunder/visualizer/graphviz.py b/thunder/visualizer/graphviz.py index 8b3ebd78a9..db3a98f17f 100644 --- a/thunder/visualizer/graphviz.py +++ b/thunder/visualizer/graphviz.py @@ -31,7 +31,6 @@ def to_graphviz_dag(g: Graph) -> graphviz.Digraph: return dot -def create_graphviz_pdf(g: Graph, name='graph'): +def create_graphviz_pdf(g: Graph, name='graph', directory='./'): dot = to_graphviz_dag(g) - dot.render(name, view=False) - + dot.render(name, view=False, cleanup=True, directory=directory) diff --git a/thunder/visualizer/visualizer_helper.py b/thunder/visualizer/visualizer_helper.py new file mode 100644 index 0000000000..3abdf26121 --- /dev/null +++ b/thunder/visualizer/visualizer_helper.py @@ -0,0 +1,48 @@ +from thunder.core.trace import TraceCtx +from thunder.core.transform_common import dce +from thunder.executors.data_dependent_partition import Graph +from thunder.visualizer.graphviz import create_graphviz_pdf + +class Visualizer(): + def __init__(self, produce_hidden = False, traces_directory='traces/') -> None: + self.produce_hidden = produce_hidden + self.traces: dict[str, TraceCtx] = {} + self.hidden_traces: dict[str, TraceCtx] = {} + self.traces_directory = traces_directory + + def set_fw_initial_trace(self, trace: TraceCtx) -> None: + self.traces['fw_initial'] = dce(trace) + + def set_fw_optimized_trace(self, trace: TraceCtx) -> None: + self.traces['fw_optimized'] = dce(trace) + + def set_fw_final_trace(self, trace: TraceCtx) -> None: + self.traces['fw_final'] = dce(trace) + + def set_bw_initial_trace(self, trace: TraceCtx) -> None: + self.traces['bw_initial'] = dce(trace) + + def set_bw_optimized_trace(self, trace: TraceCtx) -> None: + self.traces['bw_optimized'] = dce(trace) + + def set_bw_final_trace(self, trace: TraceCtx) -> None: + self.traces['bw_final'] = dce(trace) + + def set_hidden_trace(self, name: str, trace: TraceCtx) -> None: + self.traces[name] = dce(trace) + + def produce(self): + for k, v in self.traces.items(): + try: + g = Graph(v) + create_graphviz_pdf(g, k, directory=self.traces_directory) + except Exception as e: + print(f"Visualizer failed to produce {k}: {e}") + + if self.produce_hidden: + for k, v in self.hidden_traces.items(): + try: + g = Graph(v) + create_graphviz_pdf(g, k, directory=self.traces_directory) + except Exception as e: + print(f"Visualizer failed to produce hidden {k}: {e}") From 7781d759830d4039b30d05f07c07f2ec703fce16 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Wed, 10 Jul 2024 19:09:37 +0300 Subject: [PATCH 004/171] Serial greedy search --- thunder/backend_optimizer/optimizer.py | 546 ++++++++++++++++++++++--- thunder/executors/passes.py | 2 +- 2 files changed, 497 insertions(+), 51 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 771ff29059..cc4d5bef1f 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -6,7 +6,7 @@ from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable from thunder.core.symbol import BoundSymbol, Symbol from thunder.executors.data_dependent_partition import Graph, Node -from thunder.core.trace import set_tracectx, reset_tracectx, get_tracectx, TraceCtx +from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors import thunder.core.transforms as transforms from thunder.visualizer.visualizer_helper import Visualizer @@ -14,7 +14,8 @@ from enum import Enum from itertools import chain import time -# import pprint +# import concurrent.futures +import pprint class OptimizerNode(): def __init__(self, node: Node): @@ -25,23 +26,31 @@ def add_candidate(self, ex: Executor, benchmark: float): self.candidate_executors[ex] = benchmark class BackendOptimizer(): + def log(self, what: str): + print(f'================================================================================ Autotune: {what}') + def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=True, log_file_name='autotune_traces_computation_time.log', visualizer: Visualizer | None = None) -> None: self.trace: TraceCtx = trace + self.incremental_search_out_trace: TraceCtx self.optimal_trace: TraceCtx = trace self.computation_graph: Graph = Graph(trace) self.executors: Sequence[Executor] = executors + self.fusion_executors: Sequence[FusionExecutor] = [ex for ex in executors if isinstance(ex, FusionExecutor)] self.empty_executor_hashable_placeholder: str = 'empty' self.placement_options: list[list[Executor]] = [] - self.optimzide_traces: list[TraceCtx] = [] + self.optimized_traces: list[TraceCtx] = [] self.always_executors: tuple[Executor, ...] = get_always_executors() self.produce_log: bool = produce_log self.log_file_name: str = log_file_name self.log_str: str = "" self.visualizer: Visualizer | None = visualizer + self.partial_costs: dict[TraceCtx, float] = {} + + self.log(f'New trace to optimize\n{self.trace}') - print('INIT TRACE') - import pprint - pprint.pprint(self.trace) + class OptimizationStrat(Enum): + EXAUSTIVE = 1 + GREEDY = 2 # TODO (matteochen): fix this def __repr__(self) -> str: @@ -64,7 +73,333 @@ def compute_time_cost(self, fn: Callable, iters: int, *args) -> tuple[float, Any return total_time / iters, out - def build_placement_options(self): + # TODO (matteochen): this has a lot in common with the exaustive search, compact them + def build_placement_options_incremental(self): + class SearchNode: + def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: + self.symbol = symbol + self.idx = idx + + # Last index inclusive + def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: list[Executor]) -> tuple[float, TraceCtx, Any]: + + # Retrive all output tensors from each subregion + tensors = [] + for i in range(last_idx+1): + if not isinstance(trace_in.bound_symbols[i], BoundSymbol): + raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') + s = trace_in.bound_symbols[i] + # For each bsym region we expect to output a Tensor + tensors.append(s.output) + # print('Tensors inside partial trace') + # for t in tensors: + # print(t) + + forced_return_bsym = trace_in.bound_symbols[-1].from_bsym(args=tensors) # Should not be an Interface type at this point + + t = from_trace(trace_in) + # Cut the trace to the required depth + t.bound_symbols = list(trace_in.bound_symbols)[:last_idx+1] + + t.bound_symbols.append(forced_return_bsym) + configuration.append(Executor(name=self.empty_executor_hashable_placeholder)) # Empty executor for the forced_return + + # self.log(f'Debug\n{len(t.bound_symbols)}\n{len(exs)}') + # self.log(f'Debug\n{(t)}\n') + + # Place the assigned symbols + placed_t = self.place_optimizers(t, configuration) + + cost, answer = self.benchmark_trace(placed_t) + self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') + self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') + self.log(f'Assigned executor = {configuration[-2].name}') + self.log(f'Time = {cost/1000000} ms') + self.partial_costs[t] = cost + return cost, placed_t, answer + + # We assign an internal id to each symbol based on its idx inside the bound_symbols list + def search(node: SearchNode, configuration: list[Executor]): + + def continue_search(time_inc: float): + if node.idx+1 < max_len: + new_idx: int = node.idx + 1 + new_symbol: BoundSymbolInterface = bound_symbols[new_idx] + search(SearchNode(new_symbol, new_idx), configuration) + else: + all_configurations.append(configuration) + + has_backend = False + min_cost = float('inf') + min_cost_ex = None + ex: Executor + for ex in self.executors: + if not isinstance(node.symbol, BoundSymbol): + raise AssertionError("Receive a symbol which is not a BoundSymbol") + if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): + # print(f'{node.idx}-{ex._name} can execute symbol {node.symbol.sym.name}') + # safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) + has_backend = True + + configuration.append(ex) + cost, extrace, tensor_out = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + configuration.pop() + + if cost < min_cost: + min_cost = cost + min_cost_ex = ex + + if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): + # print(f'{node.idx}-{ex._name} can fuse symbol {node.symbol.sym.name}') + # safe_update_dict(node.idx, ExecutorType.FUSER, ex) + has_backend = True + + configuration.append(ex) + cost, extrace, tensor_out = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + configuration.pop() + + if cost < min_cost: + min_cost = cost + min_cost_ex = ex + + if not has_backend: + configuration.append(empty_executor) + continue_search(0.0) + else: + if min_cost_ex is None: + raise AssertionError("Unexpected min cost executor or trace: None") + self.log(f'For id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}\n') + # log_min_cost_trace(min_cost_trace) + configuration.append(min_cost_ex) + continue_search(min_cost) + + # res: dict[int, dict[ExecutorType, list[Executor]]] = {} + bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols + max_len = len(bound_symbols) + + all_configurations: list[list[Executor]] = [] + # Is the name reserved? + empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + + if len(bound_symbols) > 0: + search(SearchNode(bound_symbols[0], 0), []) + self.placement_options = all_configurations + + # TODO (matteochen): this has a lot in common with the exaustive search, compact them + # def build_placement_options_incremental(self): + # class SearchNode: + # def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: + # self.symbol = symbol + # self.idx = idx + + # def retrieve_executors_from_trace(trace_in: TraceCtx, last_symbol_idx:int = -1) -> list[Executor]: + # executors: list[Executor] = [] + # if last_symbol_idx == -1: + # last_symbol_idx = len(trace_in.bound_symbols) + # for i in range(last_symbol_idx): + # if not isinstance(trace_in.bound_symbols[i], BoundSymbol): + # raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') + # s = trace_in.bound_symbols[i] + # if s.sym.executor is None: + # executors.append(empty_executor) + # else: + # executors.append(s.sym.executor) + # return executors + + # # Last index inclusive + # def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, new_ex: Executor) -> tuple[float, TraceCtx, Any]: + + # exs: list[Executor] = retrieve_executors_from_trace(trace_in, last_idx) + # # for i in range(last_idx): + # # if not isinstance(trace_in.bound_symbols[i], BoundSymbol): + # # raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') + # # s = trace_in.bound_symbols[i] + # # if s.sym.executor is None: + # # exs.append(empty_executor) + # # else: + # # exs.append(s.sym.executor) + # exs.append(new_ex) + + # # Retrive all output tensors from each subregion + # tensors = [] + # for i in range(last_idx+1): + # if not isinstance(trace_in.bound_symbols[i], BoundSymbol): + # raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') + # s = trace_in.bound_symbols[i] + # # For each bsym region we expect to output a Tensor + # tensors.append(s.output) + # # print('Tensors inside partial trace') + # # for t in tensors: + # # print(t) + + # forced_return_bsym = trace_in.bound_symbols[-1].from_bsym(args=tensors) # Should not be an Interface type at this point + + # t = from_trace(trace_in) + # # Cut the trace to the required depth + # t.bound_symbols = list(trace_in.bound_symbols)[:last_idx+1] + + # t.bound_symbols.append(forced_return_bsym) + # exs.append(Executor(name=self.empty_executor_hashable_placeholder)) # Empty executor for the forced_return + + # # self.log(f'Debug\n{len(t.bound_symbols)}\n{len(exs)}') + # # self.log(f'Debug\n{(t)}\n') + + # # Place the assigned symbols + # placed_t = self.place_optimizers(t, exs) + + # cost, answer = self.benchmark_trace(placed_t) + # self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') + # self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') + # self.log(f'Assigned executor = {exs[-2].name}') + # self.log(f'Time = {cost/1000000} ms') + # self.partial_costs[t] = cost + # return cost, placed_t, answer + + # # We assign an internal id to each symbol based on its idx inside the bound_symbols list + # def search(node: SearchNode, time_so_far: float): + + # def continue_search(time_inc: float): + # if node.idx+1 < max_len: + # new_idx: int = node.idx + 1 + # new_symbol: BoundSymbolInterface = bound_symbols[new_idx] + # search(SearchNode(new_symbol, new_idx), time_so_far + time_inc) + # else: + # all_configurations.append(retrieve_executors_from_trace(self.incremental_search_out_trace)) + # self.log(f'Incremental search ended:\n{self.incremental_search_out_trace}\n{all_configurations[0]}') + + # # def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): + # # if idx not in res: + # # res[idx] = {} + # # res[node.idx][type] = [ex] + # # else: + # # if type not in res[idx]: + # # res[node.idx][type] = [ex] + # # else: + # # res[node.idx][type].append(ex) + + # def log_min_cost_trace(trace: TraceCtx): + # self.log(f'Min cost trace:\n{trace}') + # b: BoundSymbol + # for b in trace.bound_symbols: + # self.log(f'sym = {b.sym.name} , ex = {b.sym.executor}') + + # def extend_min_cost_trace(trace_in: TraceCtx, idx_from_to_extend: int): + # new_items = list(self.trace.bound_symbols[idx_from_to_extend:]) + # # Remove the mock return statement + # trace_in.bound_symbols.pop() + # trace_in.bound_symbols.extend(new_items) + + # def update_self_trace(trace_in: TraceCtx): + # self.incremental_search_out_trace = from_trace(trace_in) + # self.incremental_search_out_trace.bound_symbols = list(trace_in.bound_symbols) + + # has_backend = False + # min_cost = float('inf') + # min_cost_ex = None + # min_cost_trace = from_trace(self.incremental_search_out_trace) + # min_cost_trace.bound_symbols = list(self.incremental_search_out_trace.bound_symbols) + # # self.log(f'New iter, node idx = {node.idx}') + # log_min_cost_trace(min_cost_trace) + + # trace_iter = from_trace(self.incremental_search_out_trace) + # trace_iter.bound_symbols = list(self.incremental_search_out_trace.bound_symbols) + + # # Seach for last placed executor index in min_cost_trace + # idx = 0 + # while idx < len(min_cost_trace.bound_symbols) and not self.bsym_assigned(min_cost_trace.bound_symbols[idx]): + # idx += 1 + # while idx < len(min_cost_trace.bound_symbols) and self.bsym_assigned(min_cost_trace.bound_symbols[idx]): + # idx += 1 + # # With Fusion operators, our trace will be collapsed. If the min_cost_trace is assigned to a trace that comes out from a fusion pass + # # the length of the partial trace (local optimal) to be injected inside benchmark_partial_trace is < node.idx + # idx = min(idx, node.idx) + + # ex: Executor + # for ex in self.executors: + # if not isinstance(node.symbol, BoundSymbol): + # raise AssertionError("Receive a symbol which is not a BoundSymbol") + # if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): + # # print(f'{node.idx}-{ex._name} can execute symbol {node.symbol.sym.name}') + # # safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) + # has_backend = True + + # cost, extrace, tensor_out = benchmark_partial_trace(self.incremental_search_out_trace, idx, ex) + + # if cost < min_cost: + # min_cost = cost + # min_cost_ex = ex + # min_cost_trace = extrace + + # if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): + # # print(f'{node.idx}-{ex._name} can fuse symbol {node.symbol.sym.name}') + # # safe_update_dict(node.idx, ExecutorType.FUSER, ex) + # has_backend = True + + # cost, extrace, tensor_out = benchmark_partial_trace(self.incremental_search_out_trace, idx, ex) + + # if cost < min_cost: + # min_cost = cost + # min_cost_ex = ex + # min_cost_trace = extrace + + # if not has_backend: + # continue_search(0.0) + # # configuration.pop(-1) + # else: + # if min_cost_ex is None or min_cost_trace is None: + # raise AssertionError("Unexpected min cost executor or trace: None") + # self.log(f'For id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}') + # if node.idx + 1 < max_len: + # extend_min_cost_trace(min_cost_trace, node.idx+1) + # # log_min_cost_trace(min_cost_trace) + # update_self_trace(min_cost_trace) + # continue_search(min_cost) + + # # Assign search initial trace + # self.incremental_search_out_trace = from_trace(self.trace) + # self.incremental_search_out_trace.bound_symbols = list(self.trace.bound_symbols) + + # # res: dict[int, dict[ExecutorType, list[Executor]]] = {} + # bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols + # max_len = len(bound_symbols) + + # all_configurations: list[list[Executor]] = [] + # # Is the name reserved? + # empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + + # if len(bound_symbols) > 0: + # search(SearchNode(bound_symbols[0], 0), 0.0) + # self.placement_options = all_configurations + + # This expects a trace after the placement call. + # Nvfuser can be slower on the single trace region but can be faster by combining more of them, try to fuse then and compare + def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: + best_trace: TraceCtx = trace_in + best_time, _ = self.benchmark_trace(best_trace) + trace_in_time = best_time + + self.log('Try to fuse') + + for bsym in trace_in.bound_symbols: + print(f'subsymbols: {bsym.subsymbols}') + + for ex in self.fusion_executors: + self.log(f'Try to fuse executor {ex.name}') + extrace = ex.fusion_pass(trace_in) + self.log(f'Fused trace:\n{extrace}') + extrace_time, _ = self.benchmark_trace(extrace) + self.log(f'Fused trace time:{extrace_time}') + + if extrace_time < best_time: + best_time = extrace_time + best_trace = extrace + + self.log(f'Trace in (time = {trace_in_time}):\n{trace_in}') + self.log(f'Best fused trace (time = {best_time}):\n{best_trace}') + + return best_trace + + def build_placement_options_exaustive(self): class ExecutorType(Enum): OPERATOR = 1 FUSER = 1 @@ -82,7 +417,7 @@ def continue_search(): new_symbol: BoundSymbolInterface = bound_symbols[new_idx] search(SearchNode(new_symbol, new_idx), configuration) else: - print(f'reached end of search for this tree branch {configuration}') + # print(f'reached end of search for this tree branch {configuration}') all_configurations.append(list(configuration)) def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): @@ -122,26 +457,108 @@ def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): res: dict[int, dict[ExecutorType, list[Executor]]] = {} bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols - bound_symbols_name = [s.sym.name for s in bound_symbols] max_len = len(bound_symbols) all_configurations: list[list[Executor]] = [] # Is the name reserved? empty_executor = Executor(name=self.empty_executor_hashable_placeholder) - print(f'input trace bound symbols name len {len(bound_symbols_name)}: {bound_symbols_name}') - if len(bound_symbols) > 0: search(SearchNode(bound_symbols[0], 0), []) - print('end of search') self.placement_options = all_configurations - print('config len', len(all_configurations)) - # for config in all_configurations: - # c_str = [str(c.name) for c in config] - # c_str = " ".join(c_str) - # print(c_str) - def place_optimizers(self, executor_list: list[Executor]) -> TraceCtx: + # def build_placement_options_parallel(self): + # class ExecutorType(Enum): + # OPERATOR = 1 + # FUSER = 1 + + # class SearchNode: + # def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: + # self.symbol = symbol + # self.idx = idx + + # # We assign an internal id to each symbol based on its idx inside the bound_symbols list + # def search(node: SearchNode, configuration, all_configurations, level = 0): + # def update(): + # # print(f'{node.idx + 1} >= {max_len}, reached end of search for this tree branch (len = {len(configuration)}) {configuration}') + # all_configurations.append(list(configuration)) + + # def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): + # if idx not in res: + # res[idx] = {} + # res[node.idx][type] = [ex] + # else: + # if type not in res[idx]: + # res[node.idx][type] = [ex] + # else: + # res[node.idx][type].append(ex) + + # futures = [] + # with concurrent.futures.ThreadPoolExecutor(max_workers=100) as concurrent_executor: + + # has_backend = False + # new_idx: int = node.idx + 1 + + # if new_idx >= max_len: + # # As this is the last symbol, we expect a return statement by default + # configuration.append(empty_executor) + # update() + # return + + # new_symbol: BoundSymbolInterface = bound_symbols[new_idx] + # new_node = SearchNode(new_symbol, new_idx) + + # ex: Executor + # for ex in self.executors: + + # if not isinstance(node.symbol, BoundSymbol): + # raise AssertionError("Receive a symbol which is not a BoundSymbol") + # if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): + # safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) + # has_backend = True + # configuration.append(ex) + # futures.append(concurrent_executor.submit(search, new_node, list(configuration), all_configurations, level+1)) + # configuration.pop(-1) + # if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): + # safe_update_dict(node.idx, ExecutorType.FUSER, ex) + # has_backend = True + # configuration.append(ex) + # futures.append(concurrent_executor.submit(search, new_node, list(configuration), all_configurations, level+1)) + # configuration.pop(-1) + + # if not has_backend: + # configuration.append(empty_executor) + # futures.append(concurrent_executor.submit(search, new_node, list(configuration), all_configurations, level+1)) + # configuration.pop(-1) + + # if level == 0: + # concurrent.futures.wait(futures) + + # res: dict[int, dict[ExecutorType, list[Executor]]] = {} + # bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols + # bound_symbols_name = [s.sym.name for s in bound_symbols] + # max_len = len(bound_symbols) + + # all: list[list[Executor]] = [] + # # Is the name reserved? + # empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + + # print(f'input trace bound symbols name len {len(bound_symbols_name)}: {bound_symbols_name}') + + # import time + + # if len(bound_symbols) > 0: + # start = time.time_ns() + # search(SearchNode(bound_symbols[0], 0), [], all) + # end = time.time_ns() + # print(f'End of search, tot time = {(end - start)/1000000} ms. Configurations len = {len(all)}') + # self.placement_options = all + # # for config in all_configurations: + # # c_str = [str(c.name) for c in config] + # # c_str = " ".join(c_str) + # # print(c_str) + + def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: from thunder.executors.passes import _transform_for_operator_executor_execution @@ -173,6 +590,11 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: if bsym.sym.python_impl is not None: return None + # if self.bsym_assigned(bsym): + # return None + # if bsym.sym.executor is not None: + # return None + # We have mapped this at previous stages if ex.name == self.empty_executor_hashable_placeholder: return None @@ -204,10 +626,7 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE - # for s, o in zip(self.trace.bound_symbols, executor_list): - # print(f'{s} -> {o}') - - extrace = transforms.visitor_transform_paired(self.trace, visit, zip(self.trace.bound_symbols, executor_list)) + extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) # Restores original variables bound_symbols: list[BoundSymbol] = [] @@ -217,8 +636,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: extrace.bound_symbols = bound_symbols - # print('============================================ trace before fusion pass') - # pprint.pprint(extrace) + # self.log(f'Place optimizer, before fusion pass trace:\n{extrace}') # We have to temporary clear the subsymbols of already claimed symbols by not fusion ops, otherwise fusion ops will check recursively subsymbols and clear all the current placements cached_subsymbols: dict[str, Sequence[BoundSymbolInterface]] = {} @@ -241,39 +659,69 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: if hash in cached_subsymbols: bsym.subsymbols = cached_subsymbols[hash] - # print('============================================ trace after fusion pass') - # pprint.pprint(extrace) + # self.log(f'Place optimizer, after fusion pass trace:\n{extrace}') # Apply always executors extrace = _transform_for_operator_executor_execution(extrace, self.always_executors) - # print('============================================ trace after always executors pass') - # pprint.pprint(extrace) + # self.log(f'Place optimizer, after always executors:\n{extrace}') return extrace - def build_search_space(self): + # TODO (matteochen): add config for exaustive search or incremental one + def optimize(self, strat: OptimizationStrat = OptimizationStrat.GREEDY): import thunder.core.codeutils as cutils - self.build_placement_options() + def greedy(): + # This builds one option by default + self.build_placement_options_incremental() + + self.log(f'Placement options size: {len(self.placement_options)}') + + if len(self.placement_options) != 1: + raise AssertionError("Unexpected placement options size") - for option in self.placement_options: - option_str = [str(ex.name) for ex in option] - option_str = '-'.join(option_str) - # print(f'============================================ optimizers len {len(option)}: {option_str}') - trace = self.place_optimizers(option) + option = self.placement_options[0] + # option_str = [str(ex.name) for ex in option] + # option_str = '-'.join(option_str) + trace = self.place_optimizers(self.trace, option) + # trace = self.try_to_fuse_after_executors_placement(trace) - if self.visualizer is not None: - sig_name = cutils.get_siginfo_name(trace) - # TODO (matteochen): consider adding more infos for naming - self.visualizer.set_hidden_trace(f'hidden-{sig_name}-{option_str}', trace) + # There are no hidden placements hence do not call the visualizer - self.optimzide_traces.append(trace) + # Append the unique trace + self.optimized_traces.append(trace) + + def exaustive(): + # This builds one option by default + self.build_placement_options_exaustive() + + self.log(f'Placement options size: {len(self.placement_options)}') + + for option in self.placement_options: + option_str = [str(ex.name) for ex in option] + option_str = '-'.join(option_str) + # print(f'============================================ optimizers len {len(option)}: {option_str}') + trace = self.place_optimizers(self.trace, option) + + if self.visualizer is not None: + sig_name = cutils.get_siginfo_name(trace) + # TODO (matteochen): consider adding more infos for naming + self.visualizer.set_hidden_trace(f'hidden-{sig_name}-{option_str}', trace) + + self.optimized_traces.append(trace) + + if strat == self.OptimizationStrat.GREEDY: + greedy() + elif strat == self.OptimizationStrat.EXAUSTIVE: + exaustive() + else: + raise AssertionError('Optimization strat not implemented') def get_optimal_trace(self) -> TraceCtx: return self.optimal_trace - def benchmark_trace(self, trace: TraceCtx) -> float: + def benchmark_trace(self, trace: TraceCtx) -> tuple[float, Any]: input_args = [] @@ -347,15 +795,12 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: else: raise AssertionError('Unexpexcted args type') - # print('========================================= benchmark_trace: input_args') - # print_input_args(input_args, level=0) - # TODO (matteochen): measure time trace_tok = set_tracectx(trace) # Obtain the python executable string executable_str = trace.python_callable() - t, _ = self.compute_time_cost(executable_str, 10, *input_args) + t, answer = self.compute_time_cost(executable_str, 10, *input_args) reset_tracectx(trace_tok) @@ -371,16 +816,17 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: self.log_str += trace.python() self.log_str += '\n#############################################################################################################\n' - # print('========================================= benchmark_trace out') - # print_trace_execution_output(out) + return t, answer + + def bsym_assigned(self, bsym: BoundSymbol) -> bool: + return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) - return t def benchmark_traces(self): min_run_time = float('inf') optimal_trace: TraceCtx = self.trace # Assign initial value for unbound errors - for trace in self.optimzide_traces: - trace_time = self.benchmark_trace(trace) + for trace in self.optimized_traces: + trace_time, _ = self.benchmark_trace(trace) if trace_time < min_run_time: min_run_time = trace_time optimal_trace = trace diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index bda0d6a2dc..e74704c40c 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -153,7 +153,7 @@ def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[E trace = dce(trace) backend_optimizer = BackendOptimizer(trace, executors_list, produce_log=True, log_file_name=f'autotune_transform_for_execution_{sig_name}.log', visualizer=visualizer) - backend_optimizer.build_search_space() + backend_optimizer.optimize() backend_optimizer.benchmark_traces() extrace = backend_optimizer.get_optimal_trace() From 99eb302953bb91611b4a9c2003d89070e3f4a31f Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Thu, 11 Jul 2024 16:44:24 +0300 Subject: [PATCH 005/171] Extended incremental strat with a fusion try after the greedy search and comparing both with the priority approach --- thunder/backend_optimizer/optimizer.py | 321 ++++++++++++++----------- thunder/executors/nvfuserex_impl.py | 4 +- 2 files changed, 186 insertions(+), 139 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index cc4d5bef1f..5ba7d47856 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -15,7 +15,6 @@ from itertools import chain import time # import concurrent.futures -import pprint class OptimizerNode(): def __init__(self, node: Node): @@ -38,7 +37,7 @@ def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=T self.fusion_executors: Sequence[FusionExecutor] = [ex for ex in executors if isinstance(ex, FusionExecutor)] self.empty_executor_hashable_placeholder: str = 'empty' self.placement_options: list[list[Executor]] = [] - self.optimized_traces: list[TraceCtx] = [] + self.optimized_traces: list[dict[str, TraceCtx]] = [] self.always_executors: tuple[Executor, ...] = get_always_executors() self.produce_log: bool = produce_log self.log_file_name: str = log_file_name @@ -62,26 +61,22 @@ def write(self, file_name): file.write(s) file.close() - def compute_time_cost(self, fn: Callable, iters: int, *args) -> tuple[float, Any]: - total_time = 0 - out = None - for _ in range(iters): - time_s = time.time_ns() - out = fn(*args) - time_e = time.time_ns() - total_time += (time_e - time_s) - - return total_time / iters, out - # TODO (matteochen): this has a lot in common with the exaustive search, compact them def build_placement_options_incremental(self): + import sys + + old_max_recursion = sys.getrecursionlimit() + # TODO (matteochen): parametrize this + sys.setrecursionlimit(20000) + + class SearchNode: def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: self.symbol = symbol self.idx = idx # Last index inclusive - def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: list[Executor]) -> tuple[float, TraceCtx, Any]: + def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: list[Executor]) -> tuple[float, TraceCtx]: # Retrive all output tensors from each subregion tensors = [] @@ -110,13 +105,15 @@ def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: li # Place the assigned symbols placed_t = self.place_optimizers(t, configuration) - cost, answer = self.benchmark_trace(placed_t) + cost, answer = benchmark_trace(placed_t) + del answer self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') self.log(f'Assigned executor = {configuration[-2].name}') self.log(f'Time = {cost/1000000} ms') + # TODO (matteochen): log this to file self.partial_costs[t] = cost - return cost, placed_t, answer + return cost, placed_t # We assign an internal id to each symbol based on its idx inside the bound_symbols list def search(node: SearchNode, configuration: list[Executor]): @@ -133,6 +130,7 @@ def continue_search(time_inc: float): min_cost = float('inf') min_cost_ex = None ex: Executor + # TODO (matteochen): do parallel for for ex in self.executors: if not isinstance(node.symbol, BoundSymbol): raise AssertionError("Receive a symbol which is not a BoundSymbol") @@ -142,7 +140,7 @@ def continue_search(time_inc: float): has_backend = True configuration.append(ex) - cost, extrace, tensor_out = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + cost, extrace = benchmark_partial_trace(self.trace, node.idx, list(configuration)) configuration.pop() if cost < min_cost: @@ -155,7 +153,7 @@ def continue_search(time_inc: float): has_backend = True configuration.append(ex) - cost, extrace, tensor_out = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + cost, extrace = benchmark_partial_trace(self.trace, node.idx, list(configuration)) configuration.pop() if cost < min_cost: @@ -185,6 +183,8 @@ def continue_search(time_inc: float): search(SearchNode(bound_symbols[0], 0), []) self.placement_options = all_configurations + sys.setrecursionlimit(old_max_recursion) + # TODO (matteochen): this has a lot in common with the exaustive search, compact them # def build_placement_options_incremental(self): # class SearchNode: @@ -247,7 +247,7 @@ def continue_search(time_inc: float): # # Place the assigned symbols # placed_t = self.place_optimizers(t, exs) - # cost, answer = self.benchmark_trace(placed_t) + # cost, answer = benchmark_trace(placed_t) # self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') # self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') # self.log(f'Assigned executor = {exs[-2].name}') @@ -372,30 +372,44 @@ def continue_search(time_inc: float): # self.placement_options = all_configurations # This expects a trace after the placement call. - # Nvfuser can be slower on the single trace region but can be faster by combining more of them, try to fuse then and compare + # Fusion operators as nvFuser can be slower on the single trace region but can be faster by combining more of them, + # try to fuse then and compare def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: + + def count_fusion_regions(trace_in: TraceCtx) -> int: + count = 0 + for bsym in trace_in.bound_symbols: + if isinstance(bsym.sym.executor, FusionExecutor): + count += 1 + # ex.fuseion_pass regions are zero indexed + return max(0, count) + best_trace: TraceCtx = trace_in - best_time, _ = self.benchmark_trace(best_trace) + best_time, answer = benchmark_trace(best_trace) + del answer trace_in_time = best_time self.log('Try to fuse') - for bsym in trace_in.bound_symbols: - print(f'subsymbols: {bsym.subsymbols}') + # for bsym in trace_in.bound_symbols: + # print(f'subsymbols: {bsym.subsymbols}') + + fusion_regions = count_fusion_regions(trace_in) for ex in self.fusion_executors: self.log(f'Try to fuse executor {ex.name}') - extrace = ex.fusion_pass(trace_in) + extrace = ex.fusion_pass(trace_in, fusion_regions) self.log(f'Fused trace:\n{extrace}') - extrace_time, _ = self.benchmark_trace(extrace) - self.log(f'Fused trace time:{extrace_time}') + extrace_time, answer = benchmark_trace(extrace) + del answer + self.log(f'Fused trace time:{extrace_time/1000000} ms') if extrace_time < best_time: best_time = extrace_time best_trace = extrace - self.log(f'Trace in (time = {trace_in_time}):\n{trace_in}') - self.log(f'Best fused trace (time = {best_time}):\n{best_trace}') + self.log(f'Trace in (time = {trace_in_time / 1000000} ms):\n{trace_in}') + self.log(f'Best fused trace (time = {best_time / 1000000} ms):\n{best_trace}') return best_trace @@ -671,26 +685,37 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: # TODO (matteochen): add config for exaustive search or incremental one def optimize(self, strat: OptimizationStrat = OptimizationStrat.GREEDY): import thunder.core.codeutils as cutils + from thunder.executors.passes import transform_for_execution def greedy(): - # This builds one option by default + # 1. This builds one option by default self.build_placement_options_incremental() - self.log(f'Placement options size: {len(self.placement_options)}') - if len(self.placement_options) != 1: raise AssertionError("Unexpected placement options size") option = self.placement_options[0] - # option_str = [str(ex.name) for ex in option] - # option_str = '-'.join(option_str) - trace = self.place_optimizers(self.trace, option) - # trace = self.try_to_fuse_after_executors_placement(trace) + # self.log(f'sym len: {len(self.trace.bound_symbols)} options len = {len(option)}') + # self.log(f'Trace to optimize\n{self.trace}') + # self.log('Chosen options:') + # for s, o in zip(self.trace.bound_symbols, option): + # print(f'{s.sym.name} -> {o.name}') + # Place the assigned executors + trace_greedy = self.place_optimizers(self.trace, option) + # Append the unique trace + self.optimized_traces.append({'greedy': trace_greedy}) - # There are no hidden placements hence do not call the visualizer + # 2. Try to fuse additional regions from the greedy result + # Attention, if all the fused traces perform worse that the greedy one, the greedy one is returned + # TODO (matteochen): ignore a duplicated trace + trace_greedy_fused = self.try_to_fuse_after_executors_placement(trace_greedy) + self.optimized_traces.append({'fused_greedy': trace_greedy_fused}) - # Append the unique trace - self.optimized_traces.append(trace) + # 3. Try the priority list approach + trace_priority = transform_for_execution(self.trace, self.executors) + self.optimized_traces.append({'priority_list': trace_priority}) + + # There are no hidden placements hence do not call the visualizer def exaustive(): # This builds one option by default @@ -709,7 +734,7 @@ def exaustive(): # TODO (matteochen): consider adding more infos for naming self.visualizer.set_hidden_trace(f'hidden-{sig_name}-{option_str}', trace) - self.optimized_traces.append(trace) + self.optimized_traces.append({option_str: trace}) if strat == self.OptimizationStrat.GREEDY: greedy() @@ -721,102 +746,6 @@ def exaustive(): def get_optimal_trace(self) -> TraceCtx: return self.optimal_trace - def benchmark_trace(self, trace: TraceCtx) -> tuple[float, Any]: - - input_args = [] - - def print_input_args(args, level=0, show_content = False): - for e in args: - if isinstance(e, tuple) or isinstance(e, list): - print_input_args(e, level=level+1) - else: - print(f'level {level}', type(e)) - - def print_trace_execution_output(out: Any, show_content=False): - if isinstance(out, tuple): - for e in out: - print(f'{type(e)}') - else: - print(f'{type(out)}') - - def thunder_to_torch_float_dtype(byte: int) -> torch.dtype: - if (byte == 2): - return torch.float16 - elif (byte == 4): - return torch.float32 - else: - return torch.float64 - - def transform_input_tuple(t: tuple, level=0) -> tuple: - res = [] - for e in t: - if type(e) is tuple: - res.append(transform_input_tuple(e, level+1)) - else: - # print(f'level {level}', type(e)) - if isinstance(e, TensorProxy): - res.append(transform_tensor(e)) - else: - # TODO (matteochen): support more data types - raise AssertionError(f'Input arg type not recognized: {type(e)}') - return tuple(res) - - def transform_tensor(arg: TensorProxy) -> torch.Tensor: - dtype = arg.dtype - if dtype is not None and type(dtype) is thunder.dtypes.floating: - torch_dtype = thunder_to_torch_float_dtype(dtype.bytes) - # print(f'thunder type: {dtype} torch_dtype: {torch_dtype}') - else: - # TODO (matteochen): support other types - raise AssertionError(f"dtype {dtype} not supported yet") - - shape = arg.shape - device = arg.device - requires_grad = arg.requires_grad - # TODO (matteochen): Missing parallel and fsdp handling... - # TODO (matteochen): Missing support for meta types ... - tensor: torch.Tensor = torch.randn(*shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad) - # print(f'Adding tensor shape: {tensor.shape} dtype: {tensor.dtype} device: {tensor.device} requires_grad: {tensor.requires_grad}') - return tensor - - # Can we remove this check? - if isinstance(trace.args, list): - for arg in trace.args: - # print(f'current arg {arg}\ntype {type(arg)}') - if isinstance(arg, tuple): - # print('Processig tuple') - input_args.append(transform_input_tuple(arg)) - elif isinstance(arg, TensorProxy): - # print('Processig TensorProxy') - e = transform_tensor(arg) - input_args.append(e) - else: - raise AssertionError(f'Input arg type not recognized: {type(arg)}') - else: - raise AssertionError('Unexpexcted args type') - - # TODO (matteochen): measure time - trace_tok = set_tracectx(trace) - - # Obtain the python executable string - executable_str = trace.python_callable() - t, answer = self.compute_time_cost(executable_str, 10, *input_args) - - reset_tracectx(trace_tok) - - # Note, currently the forward pass returns a tuple: - # ( - # dict, - # ... - # ) - # We have to access the dict['output'] in order to get the forward computation result - - if self.produce_log: - self.log_str += f'Time taken: {t / 1000000}ms\n' - self.log_str += trace.python() - self.log_str += '\n#############################################################################################################\n' - - return t, answer def bsym_assigned(self, bsym: BoundSymbol) -> bool: return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) @@ -825,14 +754,132 @@ def bsym_assigned(self, bsym: BoundSymbol) -> bool: def benchmark_traces(self): min_run_time = float('inf') optimal_trace: TraceCtx = self.trace # Assign initial value for unbound errors - for trace in self.optimized_traces: - trace_time, _ = self.benchmark_trace(trace) + best_label = "" + + for trace_info in self.optimized_traces: + + label = None + trace = None + for k, v in trace_info.items(): + label = k + trace = v + + trace_time, res = benchmark_trace(trace) + del res + self.log(f'Benchmark trace "{label}" (time = {trace_time / 1000000} ms):\n{trace}') if trace_time < min_run_time: min_run_time = trace_time optimal_trace = trace + best_label = label + + self.log(f'Benchmark end: Best trace "{best_label}":\n{optimal_trace}') self.optimal_trace = optimal_trace with open(self.log_file_name, 'w') as file: file.write(self.log_str) file.close() + + +# This will benpchmark the input trace with the del_last_used call +def benchmark_trace(trace: TraceCtx) -> tuple[float, Any]: + from thunder.executors.passes import del_last_used + + input_args = [] + + def compute_time_cost(fn: Callable, iters: int, *args) -> tuple[float, Any]: + total_time = 0 + out = None + for _ in range(iters): + time_s = time.time_ns() + out = fn(*args) + torch.cuda.synchronize() + time_e = time.time_ns() + total_time += (time_e - time_s) + + return total_time / iters, out + + def print_input_args(args, level=0, show_content = False): + for e in args: + if isinstance(e, tuple) or isinstance(e, list): + print_input_args(e, level=level+1) + else: + print(f'level {level}', type(e)) + + def print_trace_execution_output(out: Any, show_content=False): + if isinstance(out, tuple): + for e in out: + print(f'{type(e)}') + else: + print(f'{type(out)}') + + def thunder_to_torch_float_dtype(byte: int) -> torch.dtype: + if (byte == 2): + return torch.float16 + elif (byte == 4): + return torch.float32 + else: + return torch.float64 + + def transform_input_tuple(t: tuple, level=0) -> tuple: + res = [] + for e in t: + if type(e) is tuple: + res.append(transform_input_tuple(e, level+1)) + else: + # print(f'level {level}', type(e)) + if isinstance(e, TensorProxy): + res.append(transform_tensor(e)) + else: + # TODO (matteochen): support more data types + raise AssertionError(f'Input arg type not recognized: {type(e)}') + return tuple(res) + + def transform_tensor(arg: TensorProxy) -> torch.Tensor: + dtype = arg.dtype + if dtype is not None and type(dtype) is thunder.dtypes.floating: + torch_dtype = thunder_to_torch_float_dtype(dtype.bytes) + # print(f'thunder type: {dtype} torch_dtype: {torch_dtype}') + else: + # TODO (matteochen): support other types + raise AssertionError(f"dtype {dtype} not supported yet") + + shape = arg.shape + device = arg.device + requires_grad = arg.requires_grad + # TODO (matteochen): Missing parallel and fsdp handling... + # TODO (matteochen): Missing support for meta types ... + tensor: torch.Tensor = torch.randn(*shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad) + # print(f'Adding tensor shape: {tensor.shape} dtype: {tensor.dtype} device: {tensor.device} requires_grad: {tensor.requires_grad}') + return tensor + + # Can we remove this check? + if isinstance(trace.args, list): + for arg in trace.args: + # print(f'current arg {arg}\ntype {type(arg)}') + if isinstance(arg, tuple): + # print('Processig tuple') + input_args.append(transform_input_tuple(arg)) + elif isinstance(arg, TensorProxy): + # print('Processig TensorProxy') + e = transform_tensor(arg) + input_args.append(e) + else: + raise AssertionError(f'Input arg type not recognized: {type(arg)}') + else: + raise AssertionError('Unexpexcted args type') + + # Always benchmark trace after a deletion last used pass + trace = del_last_used(trace) + + # TODO (matteochen): measure time + trace_tok = set_tracectx(trace) + + # Obtain the python executable string + executable_str = trace.python_callable() + # TODO (matteochen): make the iters configurable + t, answer = compute_time_cost(executable_str, 1, *input_args) + + reset_tracectx(trace_tok) + + return t, answer diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 52394fb260..706c6fd28f 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -749,7 +749,7 @@ def map_redundant(x: Any) -> Any: return cse_trace # TODO Restore fusion logic here -- this just replaces supported operations in isolation at the moment - def fusion_pass(self, trace: TraceCtx) -> TraceCtx: + def fusion_pass(self, trace: TraceCtx, fusion_regions_in_trace: int = 0) -> TraceCtx: start_time_ns: int = time.time_ns() # Replace uniform with uniform_philox and rng state operators for better rematerialization from thunder.core.rematerialization import replace_uniform @@ -782,7 +782,7 @@ def _can_fuse_node(n: Node): # Counts how many fusions (per executor) have been constructed # (Used to name fusions like nvFusion0, nvFusion1, ...) - fusion_counter: int = 0 + fusion_counter: int = fusion_regions_in_trace for bsyms in bound_symbol_groups: # TODO The following allows generating single node fusions, which # may be suboptimal for real-world performance. From 2dbef84e27dc817fb655d745b7f7578db01be83e Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 12 Jul 2024 10:22:14 +0300 Subject: [PATCH 006/171] Fixed key error after fusion_pass (symbol deleted by CSE) --- thunder/__init__.py | 1 + thunder/backend_optimizer/optimizer.py | 87 +++++++++++++++++++++----- 2 files changed, 72 insertions(+), 16 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index ce16cd1f79..17dfb2db6d 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -1000,3 +1000,4 @@ def _fn(*args, **kwargs): return original_result, original_trace return _fn + diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 5ba7d47856..ba4213701d 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,6 +1,7 @@ from typing import Any, Hashable import torch import thunder +from thunder.clang import sub from thunder.core.baseutils import BoundSymbolInterface from thunder.core.utils import check, safe_map_flat from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable @@ -389,15 +390,14 @@ def count_fusion_regions(trace_in: TraceCtx) -> int: del answer trace_in_time = best_time - self.log('Try to fuse') - # for bsym in trace_in.bound_symbols: # print(f'subsymbols: {bsym.subsymbols}') fusion_regions = count_fusion_regions(trace_in) + self.log(f'Try to fuse. Fusion regions already present: {fusion_regions}') for ex in self.fusion_executors: - self.log(f'Try to fuse executor {ex.name}') + self.log(f'Try to fuse executor {ex.name} with trace:\n{trace_in}') extrace = ex.fusion_pass(trace_in, fusion_regions) self.log(f'Fused trace:\n{extrace}') extrace_time, answer = benchmark_trace(extrace) @@ -652,28 +652,48 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: # self.log(f'Place optimizer, before fusion pass trace:\n{extrace}') - # We have to temporary clear the subsymbols of already claimed symbols by not fusion ops, otherwise fusion ops will check recursively subsymbols and clear all the current placements - cached_subsymbols: dict[str, Sequence[BoundSymbolInterface]] = {} + # proxy_names_to_ignore = set() unique_fusion_executors = set() + cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} + + if len(executor_list) != len(extrace.bound_symbols): + raise AssertionError("Invalid executor - bound_symbols lenght") + for ex, bsym in zip(executor_list, extrace.bound_symbols): - bsym_hash: str = hex(id(bsym)) - cached_subsymbols[bsym_hash] = list(bsym.subsymbols) if isinstance(ex, FusionExecutor): unique_fusion_executors.add(ex) - else: - bsym.subsymbols = () + elif isinstance(ex, OperatorExecutor): + if isinstance(bsym.output, TensorProxy): + t_proxy_name: str = bsym.output.name + cached_subsymbols[t_proxy_name] = list(bsym.subsymbols) + # This will leave out these symbols from the fusion pass + bsym.subsymbols = [] + + # proxy_names_to_ignore.add(t_proxy.name) + + # self.log(f'To ignore:\n{proxy_names_to_ignore}') + + self.log(f'Before fusion pass trace\n{extrace}') # Perform fusion pass + # TODO (matteochen): filter for the current fusion operator as we wanna find the most efficient one for ex in unique_fusion_executors: extrace = ex.fusion_pass(extrace) - # Restore the subsymbols - for bsym in extrace.bound_symbols: - hash = hex(id(bsym)) - if hash in cached_subsymbols: - bsym.subsymbols = cached_subsymbols[hash] - - # self.log(f'Place optimizer, after fusion pass trace:\n{extrace}') + self.log(f'After fusion pass trace\n{extrace}') + + # Restore subsymbols + # TODO (matteochen): Improve this search + for k, v in cached_subsymbols.items(): + # Note some symbols may be cut out by the fusion pass -> CSE + # For example: + # a = 1 + 1 + # b = 1 + 1 + # c = a + b + # being replaced by c = a + a + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy) and bsym.output.name == k: + bsym.subsymbols = v # Apply always executors extrace = _transform_for_operator_executor_execution(extrace, self.always_executors) @@ -682,6 +702,34 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return extrace + + def clear_bad_inputs(self, trace_in: TraceCtx): + + def args_eq(a, b) -> bool: + if len(a) != len(b): + return False + for obj_a, obj_b in zip(a, b): + if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): + if obj_a.name != obj_b.name: + return False + return True + + def clear(bsym: BoundSymbol, input): + size = len(bsym.subsymbols) + if size > 0: + for subsym in bsym.subsymbols: + if not args_eq(subsym.args, input): + print(f'Sub = {subsym.sym.name} Args = {subsym.args}') + print(f'Got subsymbol {subsym.sym.name} with different inputs from {bsym.sym.name}') + subsym.args = tuple(list(input)) + clear(subsym, input) + + # Solve the issue of nvfuser mismatrch input args + for bsym in trace_in.bound_symbols: + if bsym.sym.executor is not None: + print(f'Calling clear for {bsym.sym.name} Args({type(bsym.args)}) = {bsym.args}\n') + clear(bsym, bsym.args) + # TODO (matteochen): add config for exaustive search or incremental one def optimize(self, strat: OptimizationStrat = OptimizationStrat.GREEDY): import thunder.core.codeutils as cutils @@ -701,7 +749,14 @@ def greedy(): # for s, o in zip(self.trace.bound_symbols, option): # print(f'{s.sym.name} -> {o.name}') # Place the assigned executors + self.log(f'Placing optimizers for greedy trace:\n{self.trace}') + for s, o in zip(self.trace.bound_symbols, option): + print(f'{s.sym.name} -> {o.name}') trace_greedy = self.place_optimizers(self.trace, option) + self.log(f'Greedy trace:\n{trace_greedy}') + + self.clear_bad_inputs(trace_greedy) + # Append the unique trace self.optimized_traces.append({'greedy': trace_greedy}) From 8c4fd7dcf2cfb2ddd353db2f344acbf13ea64924 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 12 Jul 2024 10:24:02 +0300 Subject: [PATCH 007/171] Removed from tracking --- examples/dev/backward_trc.pdf | Bin 17024 -> 0 bytes examples/dev/backward_trc_final.pdf | Bin 14194 -> 0 bytes examples/dev/backward_trc_fusion.pdf | Bin 14244 -> 0 bytes examples/dev/forward_trc.pdf | Bin 13368 -> 0 bytes 4 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 examples/dev/backward_trc.pdf delete mode 100644 examples/dev/backward_trc_final.pdf delete mode 100644 examples/dev/backward_trc_fusion.pdf delete mode 100644 examples/dev/forward_trc.pdf diff --git a/examples/dev/backward_trc.pdf b/examples/dev/backward_trc.pdf deleted file mode 100644 index 208da6ea8a2479559a06032d0d74978f696871e4..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 17024 zcma*P1y~(R7A>4WfZ!I~-Q@rWcXxMpcemi~?hxE9xVyUq4ess|G(Y6dop)#E{onVx zIo-8)?W$e1tGdsj*6Jpa6%?ic&;p@I>Q5`5pqTLK@on_Yp*T75=_HM;O&m?}fp02# zD13Z;I$<+QM;{h@@9s^10BT z@4L6~j{Vs(zBP7Q31j}g{RY?hFo50h^kfYGJF;PY>$ms1*YDfzlf8?g@w8qZyz)q| zDpDMMHclDJ-F9$~YSIO+eAl5vS&6my$!9n21k8JV%FwNoGleyFnY^Viod;~$pXh=H>cOtu&z0>f@as5P9m|%BkpS8NW@vxv(A6T#V}d2Fxuc8Yvi(yQ z?^j)J-BLK67q_RE=?w0jZM-P5^r%jrCD-TGLw~2i-)vuPH-AN}AO0Ggh46l!jK3X> z4xSBmeZCnlsgqcL1%2JwmiZ=~&f_)t)WNp9{qSk+H|u6kXQn0FYe>h{<(sYT^Rf4) z3;XMBx&o)y^P^$!C4ABBg3gQc(_qz}pIZ7+mSmqkrdkZ&d;?*@3F@`60EWZRdWwZ? zQlrmuSg41%(S=mnH5a*wqZWIaQ?0@?w*hbfly7pAX&`)at&Gi0=VZv@{Jbp$h)g0ZUc{1y4 z_e1IVuFvhD@hZ56g{g{-IZO0vPTMazXKVwQD{&Nk?Zs$#ZuF&dB|de(uulmZ(Q-)+O|_9tAfD&KKL|Y{(F}{s|q|!Ta?! zKXGTpu!JLPcwdQlfmN$M%fM<6UIK+lwzZG3pAL4?P%Cs{>4V^^VJJ?!+0ve)Y{*lJ zvu1VfsT~gPWmjP>mrMz6VM)u5)=ijwkq_#okVrwrXOgYY>kN>AZC1ibjPbP?AlX=ZF(Mx+U45PhH=|JN2T=4nq4wHwUBm zi%WYd7xCS*)tsqKg8J`0(Gwymx#GH+AfhX}Gmrg_49I#uY8bk3e{NRO8PksYR-V^A z<%MGiqAuj_c&9jWg4s9F`sA>ykM(0iGsv@ZvD-Od1tr1II-}8&;Gl+BHWZmsIU9*X z@dNTTl%x*~VS%-IfR{u{3OI6ZxGY5#M^g96+{oPs-+a#|kbG%!fng4q_dW27$gf8W(%5(!sO6E`V&Y*E2C)cT*xZ82&U<~xLY&i^rc7t z(JxN^xpgG^^m^a}b^eYKUuy)9z$IfP#ELhvV`rJ05|U^q{YCjI?k9pNqS0n~vQVby z&_xY2f**Wtt%Edssn+=|-65LqV4A_hgp*rj-EWJc8;Lc1f!LIL-TJzzwlq;gpLs1- zeFa6qIz6H}vMD)ai^a7Rt8Se`Lv@I=svzUIT|Y>xW~H0l z?dL9@Xoh+C=s~$mySU~T%DTRnj9|Gk8Z;rbk~aht4a2XLS*EGald;{%{-ot&<{JsS ziKNm!fDsiV6d*mMRv%bU{fRuH=;Bg{n8+E1=|E};6k?U1tOoB)cqD5I{V+d#hfk)X z%J?j9KPFUTj0qXruXLHgkh|ePbu;+CHI3Tvcee-eSJu}$~IH@IW@l_yG9MM*?>A1 z%rrw@e#eyW<+saCS2SCo$L*B0+J|W^)@C8~LTHx7;@~ZiAijF10+?gjkR@8#5Md30 zp=4Q8SgR#)BX?Dqc@Jk&S@dWEqycoyQT84P$zxjW-qQ8eN`=VfE z<6SI>3`l*)qO{!%ACT*%6xU_gvYx)3lCPHf{vB3%uCdsRS$csdv<$=`L%3;BVY8fg z@gTrw|EgC{{9svk_%rIYyrE1a_=|`Y!W595kX^TvgchKlYs!RvdI8yHs8eoUM@1`& z4XUA?R4sm{7sqw48(z*+U=^6QlYYUHC#CF#l1kE3epv^8c*R*Va}9@1LhN`R(K&g| zhTmv}9(kyr0#6(e zouMr*Ky2F!atz{dxC}`qgp*+%PwmpRvGWcsk4?`TGtrBh(VQ-!v*8+$1e>nSRV4a8 z+$k(`MdAn-9i#&i;E&M=X(eF9kgCI*iG!2_@lGp{TSwS=SpW)~LZp7mSeXyLTh@=} zg6v;8)=-%k4s(N^KXJtZ$f&@%?*u*s7+9gs0+0QQcY4FfM>?Se3!12TLW9W7@Io*x zarjT1`4sK-2r2cD$7!uvcoQ`LYL+)@?nQAUVqzyuZYjBN=@?}P;s+4CPsTOZQ#LU8 z;Ij#a#<~S#LAYeH_7>lLcen8Q>qj$#4=VhCAN*~>;1OdCSVxk7fIth~bxtt)#HUMF8kFe;`x^UL5w8I#zUVeX+Y$Jnn%V&Bbh8t{;n`J(DG3 z$riZolp zS3P)BP5kR6VqOPOIW)Qp;zrWRt1Sh6WsW&y@Fb{hf;BX1B&02EL`X+g>zFzQcD1vG zE@h%2@^l@e!o98XgU*)|+6YNM>7r~>nKY}A1=Jk@8iXlykYRfh2nzFHk64&(T03UR z3Yo+ZHO=T`i@!K$dSLHF@OE1#P!UkkQqFJ2vBsh98ykzWiEA9j7cNU7o{;4)(0%w& zOMFM=pzcUfel5))&X&aTjJC4ZFFK^Si+Su5m>cWOG(DDw!t49(q78k=x*nub^;<&* zJ^nm5v6!$ErG+wu0Eno1A1@+{5@8sKK3@i{yyP#@vVs}9)g<%@qrTBDU-bf)l4ge%6NSfw!9ZHsV+eEMo3c05ONCfyP8Ob}yIF*%J7=fT6iL_!n0)S8D2IL9x$wS15EoQ$AjNjJ*oIk+0i!D16mUg&^<+VqUEixx#YWOrI(g==7d~n0k6qktJ=lfN7 z#DUi^cOWPE>ylbzwlgx%j>e#Eod{3Q2+tY|;;bEJ38Hu86D@?GMv@o1zx9-cBt3n( z<_2U9b>fsUIaqp1(m`L%K0Df8Er9?z??a>CI2NV%Y0zmMW60@(?+*5{RhuOJHUDHw zeJOR!4i9(Ows81-?{l%VujBZAem}<-wt3M}{(Wn1ag{bCIAGl}D3RMO$f+QOh<&dx zochR=nu$Q1YcGBtp1;;9ZFZHeGH-t;PxHx$=SCCe1Rn5Wt(ErrYe^=l2 z$tGfYj(NdHI!HjpbXbX&L|Dm?qV`LfZhe=!4S0C=Xa#{C1;1)ug#GFTtT^%K&M|3$f zVysHNehE{xRLD?%9{H1g&E;;1h_BTF0#)45G)m5)TJ|bNwWAz7Xh!40X^%WQwewag z7=YB$_s$4?bvid;LspMZ{Krj&KNCvrMcTDlc#MK=ec!D`0mDsCn-ZdhX$ZxI$08L- zS^-IlRfm!|3L|L1VsDNbC*l+J(WqGt%sozRJ>PbnJo$ATWDlDz-bhC+reLb0DQ=Q5 zEqateIYqFtV(G-k@VY*>snH7@#aFnqr7d6pyXgWK(<%S6s%65%rwUN)Xp_hdY_0tc z^xvSD)aCqWGmgRQUuMqH2-|f;HPZZqpR?2`Rfrce2?mg2^6Dwdv6!_#ZCQR+LKNZR^jOhm8Laza#EzoF4 zrL07GVyv-W9GKWat#8>g#+;FGRgRqI1GX~g6~kQ5?nMA-=S_Y5N9?)eZGY6wRMpl8 z8FRsvrY8BL2p=3M!p!{RUl+%ngm#;y{0Fq+8v{&Cav!+kE$9I`rxY@~lDH`b8CiDu zL$vsm9p+A@gTY$#wvswr4C>tRMMJc_6WiGQc-mHF)S-&4Ic5b>$r_r<-ENRFDjNkV z6DwLM&|Kxqo{_C$((MgwSiUkQ_9d>wZfTVQmv#!~0fV_nGlQdy;}ksdfRanxXdyrG zyW|)}NUUa*OkhlQi8!(xQvL?gb&uGk>?#t2K4gYruo|YuwUpCX6XCU6AODxUF$Qld zEv*j*3KHh7CPn1U#$T7sX_25|y)tt3#mwh>bb_)e%Xx?8^j9&oi$10g^YQ#xqwFb( zNl#4`{&lEs|53yDNsvKY9XQ#-GAwKkOis{jGt6rZ@SSS#ze|D$8_eIW#Wqm7JrWdW zXp$wTwiBp(WIDG|hFXdVm-}B~r)Hxt^vg++xPlzTfzlX{x*_BZNVu?Q-(wOsk>n$C#4`%s z=0g%AuA&5mi}GH&M^Re1b@Z&FBcAxSde8%JqEs6h^yymJ@=>=ew3LIsL6~1kt*$`n zYbY!(yfL4sCE52dbS&B0*I<0ZaD-6*AS()c5sl1;jji`MZcUOgLaXpHdB|#3NrDx z?l|7-j(W&u?W0bI%4_Blu6v}AfcVkJDgIhYbJ2(+Tz0xp z(=>s zN8#*Unq9kMw$rvdlNy&L5Gy0AAq}vr5|;_c4}=DK)%Ro#cmv^4_m@ zQWW{RcGW;OrGbRhsRBAOqbtc#Qc`j#mCn+FM8VcIxG^?^-A~f`_%cS?b6b$gL{lLJ zmr%X3VcHCnmvvn$C}hFx({8}3)2<0i-XgT?G+pO4#wz>u-L{&Q{r7BW#OjthHrGlIeHq~?>D((H;=!8L|@ovPed$K#2O1VUiNT@N*4f%I%RQH6s zAE2?Ofw(i=P+Z^|RVTWQZ@G}hsP13RlW5yp$!}u*4g=*4`6c7ts7JmlHxY4?eQNSA zUEvcDy#7ZC&nbyk2t50B67H}@y36VFmaaiVqnrRkvO1VzMT|+CsURY!cq@ zC{giSIa~m9j<=SL{e## zu@ghuW4L5cEiK3haU3Y=*y)C34-!CFOlwYhUCi5h+FmTmK@aw;c@S%mYd^%dh)1}> zWYj0PQk2t2YNxZfsw(G_5d&@=IF+H2r&1^GydNr9?gchQnv?rL<*bF1%#qfHBCGrX zQOpJQiKkt49PK3yrH-?&RAhLcfyQ%IaqKY&@X#xtKY{yYpM^_)uEk)R(Jn6o!^Do* zBVN`tH=}k4(iNvh+D*paxD31PA-r|HCIY|M!>A@^P5Cx}PRTbkVyVh3Zy<1PJRt28hnd}S5GXZ-VVT%ZWfM9M3l~7+sK5CLr*sx&5&~T z>N9d2ud>QmfM7gj9+<1OK5oJzkNAAQzp~a_;vvgjnMlUP%9&Q0rsj|v!Ea~gIy+W6 zzBv1IB(z_x?h%n4ae+ewhn7z^TQ}@T6{W5~N-P)|^^f6VxtYktzycWAPwg94e z_>i1Wc>fgkmfm0nFP>aY8oi@Sl|si>I0D-n&xve{Xbk#* z5y7}*_g~Ika>cc%h|LMiR=P<-&&>mz2ZVZ+(z~cEreoz1aCyg4!`D}a&730hOLPOT zJ`#F+3y@mPIIb7x8mYhXE}m~d?OqR!`(Zo9Hw0mYhL?eZ+#_Zn%8E6XVR#;~;vc=J z_$I0t8AMabwXNdPS{Jln&WS4wck^4P{r4wW=&*}FiuVosyl%45MfmN{&kZ*18L5SD z^VPDO!;!68z-@Ah-=_b-m?%ZAV6W4M6k>y*{Dy5n8a~RkXd*GgpTc*4)IA@lj7IB3 zBg|9E^lh>^ef#(X&E`V%*OBj0?|t!{c83&eSo6xI$7Xtb|I9aE5@kF3ovDPT%5>@*XCQBzZQyF9nB70i(&J{`S8vbx&4fB_hF|!@g zoBn>j3*B-1vMdXF`OAa?1XboBPsLCIxF&@!a&g9;U-x_=yfYTurX~F{Y&Rf80veA} zaYIuTtFC8G7gSO<;Q(q@=HZ{=ONE@dUc7ZPz+GLmexSRMj_$>T(0AH*t4CmnBn9EZ zpDBhP!_6ZY7a<|MeW$v<-SG`HL>_t5MpltwoX=5UUU@7@+1g{Cw>TBTzodY3NuP+L z{>~2$#eui8BYe@o+nu4fvq#f9-M@?Vtzn(?Wafz@RzoNfE3@E!@=En7=lk!PR#5zQ} zcJR@gYjME{1+W@Wa-#yUE6{?GNfcL$Y*O24N++e%p2MQS5fB4;DBnoR=JW_O+)gsr z6{nC?j`JOBE)4llMbtY(x>m4`r)>|2?5QDy} zde*ERVPS~~Xul*q&BOA?)maTeDjZCpCL_y8Xhx~Og31&rq8OQIWI8BfVK-ZDWEnC) z1cyCuWOu@S0$r+x8HzZ!p%(DkV61Xl6o|-ghdjSW*t|E2i&yzR5YmqFJzll+6Qq0^ z0mZ%-I>+RxrAr>bbN=YZUSfs6Z!)xeG`O+uq*tK16u!3$EPmvXl{o1rsR zh$45u*gD{SmQZvc`b(XkLlRl8#TiSLPLz1dm&<%qzgKhnsWKYZNCk@ZUP=tcs_gHdh?a^e)f*; z=Wh@-s3$BaBWuHdA$f19cMR|y_M;J@RXe}S4J_BL;eLjR}je|h<%{kuB~Z@%sItQ~CMy&1Uutro*~ zw0AQ4TM~G)5;Sr)Gcb}D;eS{DtF!V(4mM8q21X9}?}3y5N8ec9in>p?tQj|&0FaR#W zMG06u%$0h+@-*4L?qbnevQXK+uxwGW{(3q`7=eH^_6Dq4o%RjaP zR;K*xs8;(xXlNk_!J1mT#r87!_;)*W&M&=-xrV@9i?C*wWNz>UI>H?@S_=1K`VO%r zxO0xpr3)li-CGRqfF}#vZ?ZM6LPmqgGp+$3-)+~GwZCr=^sOv?z9-xC`n?rQnOuWq zhRQ_a`O3D(j3O{ocqrqHbLv&^}HtQKAgrT+H8ccZT zdH--Gnci3hL!8WMqC(W%PX;1}fC}bZa>hMvR2r8se;9Tce!|M2W>)d0-mCfRn z6QZ7SOh#T|d(Hp`lLF zDVZmjB)*7CQOrDAtGrHdLVZ1!^|Lfeb{Mm-i`Xakhx|?$1hJbh&H4jS*L=mHTZsMY z{OnNxp^*qPq5b$W?c1Rp(ENzKM~MNk3rPw*=*&2!ubERLat zWVgT$JR_MG>h}Bct^45qf_Az)Bt*vf0Eu&0)KETut@7+TSx>$IHyj8NHAH;hja)Fe zH8PyR%qT;a+UYy$JDc$&k5tWim$HG7@rPEIem zZ1rMhMXWL@i?O*ibDjE5b}uPiK@D>c%J}Fj5mik)Rw_A|I5q?#V-WRjRD z(8~=SX}YLhIF93J8TR^&EY`0#eKiI)UY_Xb9uk#vw6)l7M(0)Y9ZyFdbQC&N>x%0s zFjRQS`9NKm*QVXUL$@*A5kQL_jQ{x_SHyvE#O!8eW7-Y2UMY-a z_^lcIHa)bvwg_NLD35Afj~MrmvoD2YFwsB`c43vJ@Fz-=JGa*;2z*x-;_B zH*!gAYlzT;O;;Lwt@(H}tr?tl02QDN8AqN$YEM4ki@+52m-u{6&iLl&iR^nj#-`Mk zr5;NGNLoWE(~l~}8iC5&;(Lf!pIiMZH|7thA`^9k61T+75{(i)l*URI33Nf(K_F^g zqI|gRdyG?_@^zh)jL|(k<7zx%JM0mPwtgsUcB+TMsVEszf94t5jlck0MiT8b8X3KW zj+q!4{Z6UB^Ahuje zb{NAz#3AH;QxBA&DEb}@KR@#dVisEEXC4e|c!lzo4g*xy=P9gA{8?FY`@T&p4E`X3 z-b6$vX;u@~w?*nXjQRa-yX=}-yY&MrciEVa|96Hav%jy!uJe`j(^TMzy`76>1RSf5I7sQj5K_{ZX zPYym{*AVfvNJ!r(g!^YV33HcDN0bdQC>qgjSX2P41sWnD*{C<#Q!hVAZ(i(Hhjn1A zCIfY26m^G&In|Lr3}WPno6C{Zv!6tmDKsn;}Lxc6MY&l-8uRPihqH4X! z&8&BVg>|aukSNhy?rgHnx=KG@NOajp!6omf;#hzA+0XO3%Tg5E(%{oppsVAgB^v0I zvo_@4xH`;L!W(vMfwub%lEZ~;FBc+9#^(2JCY^A>JyPeAQ zL{$Qv?M)&TtNMfe%tA7qPP6^P)dI9_Qtn3em4=Yqgt;mQWo=yLQ?4?{4c;y zPX2#!%lZNt=rv$cd5baW(76cl7%Urgh`~zj z^$?!*G6D`4KMXL|Fl}0sSK_Fo2~%spq8t3ey>Pa z(YRB8h+Fq8Z^m-Z+w80A9#$L+U$nqFAjl>-Kx88tun7^4LsStu-#NB=J@svAddyeX z@clkACaAC@6V4L2LI{Of%(0dlS56~TS}T-9C|kH+GjqL)CO_%Im4y9)sf&?~PrO95 z1bZg6M;Z38FH@52P`}j_ZDIGo-np4Qk3ElLwa{~lg+t%VRv^7FF^F}kH^}6Hs4@rI z(1JZa!luD;rA^;WYtYN+_hLg0QRi{9PiiJs5+~_6m8;PmUT5Fz@0QZit?#G541Yru zj220=FWr?HD9l@uXZ#kkzcOjO$%1wNDF>5AI|yjYgSu=rYX_9#zQ{W*s%oFx=ApyT zVwHh^k|Fwm3H}19oH2|3{IRq z*{`nwJ8zhSI^15uuGpMkPuBqsSr&_$wF3GGzlTkL$8e4QD)*auBV-~`oz6#dA#G^3 zfR~7qIo5O$W=43TGrgoD)-WMjT_S^)UKTT~DTI2>>@g>y3PYQpz(n_4#O!MmAB1Vs z0_^h|TjBW3rR-=1rO;v_ovz%NaB#8D(G0=)*??YH^|C2;;afRYr(h)}r3YbB*FwHq z$;JV+l_9P!H0{bzvuHhQ_DT&0&BkR(kI0wCRHfeR&wHvMV(aJ!||O1$#u=>6=TL6UCzQ^c<+Y^BB*io^|4m7MLjhW zWlZZ9TA)3nLunw_5N4XSmb@H>h;p(i7&l3oN@K~KTrBXvbc9CJEf6=5fne$lTA89_ z!!+>KMFe24@q~op@bF5YP^JurgY#Y{)_^v4$%j%<`D4g^odmsVm1{T=Zq>T&s~i25 zP@t_N0i{1?Cd^XP+MIwMbNl?X50PCXbKCNUTz;q26?zsgm}-nCoZD5ExsW-flyB*z zbT29%le_<=Fy5M`Y~!-c0f!6L$j@>y6(PB6e)>q)`D?4J{umq(|AIu%h`BK2R;kQ& zZt28&+qOs<+Lp!Acy!w~KB0=UO)U0hFjc5%)~-=u2hv~$bi8CMNoa{A@v=@8u~6{5 zI#k%Vmeola8sysym|2LSA4^H)r~IWfPZ}2H}G!VJ)+Omo6UVT~7Uo z5x3Xu?c3b7;ns{c4M&dnUxt7sHu*8Ty(ZG|+6SEw=!d{9ev?oOVCT{jG46JrYP<9l zMsaK!D{8#C4K~vwJ6+fhL=54Axq?H^7*(4m58)>J5G%B2 z?#iJI)-6KkI>vzo-}u0h6N!VSkM6H8G~iHDJBo=XllVcA0#EXo6hrvs{98Xm%*n0g zI9!Oe5OWi?khXBQ0M&;Y!FZRbofZt6Vd>&)emB{ym$_Ej>n`RGBd1tB-v!N+4GlAJ zDzPcgf9pInxgWP!RK0)h!j$9=U=8D!8NO^=1E$Hl$Bn(Oy(~m_*2^$%xn>H8yl?glU0* z!zdYh*Cz*~{yY2=I_>Xyew>O;EvAHzZCDZGc-I0~dK8)5Rt~))FD7$y?!UTA>%c5) zu2#8bWg6gTI>Q)h!<@2-s_!Jaim*+?P6*vC65ECRcF_?hK9w($22rO@e6C7^oeQW9PiXnCQ zwU7;b*x|9-O!T>quU|jg#VG%LQT4K)1=X8`HeZz4bClMS-Jc!*K{Hh`bLg72Fjx)v zk>m0?lsr%>hr)gFIYlH_(u|>H)G_%1jeN2VAu&bed)a|`W!!?_)E8*Jz2BQ1;5sVv z&$?u_;B8jA=QGJ=k4xIS=aPX*&=(_wYEdOg(7LXMMs6Xbk-r3IJ}hnpZ`_(NqKJbx zP;zCwEPoSkJ_L310pE2p&P=KbiH&4mM51NHqGs<))003y!g`C5-mQ%-J6I3#b5xw`EQw6Q$&vMgQr-RV_C|sLCvfM2hi-#VnDr{m_!<{YhL6 z$Fb+A>V&jN8qpi8gyzzw=Zo?H`cEgNKW%eD6AHv7}6DHVJ)kT%< ziEv^Lt^=W-QveEIG34pY-%jlLRT|^$U+=dahjf|yii`xV{MNGR4&VnMKpg^`gPI4L z3!B%PJDOjcFWr12*CIP2Un4KEK&945b>cfIIw)+rJi1)EyzyUoyNHVB^Kl#SoUxs8 zopGYrvA-^RE!!=-EjzIvIo7F->%v$j*cj9Q@Nr=CO6Zo|wsE7-XmmO?xOTQ#lz+Lj zsEwi{H)5&M{O*2Mh||jK`LNbgIlZ7~5-V3b~&HEniC&^goxxtcebvGrS+C*>NYIDIkmPb5>-D5 zjmQ1GQz#{qnlw`=`)KjI<`CuEQ#_j2?|oa;DDRrfp{zTVFf3=z4yzZU4Mvv_)^+~4 zFg4lAY_SwSpxrgeKR>p9*A*Nc_vs!D6bjf>gHw|fWV(n@I^yC1C-s(0PHUHM1l#xh z2m``e*3I0f5`?>ToFa41_%KGyur#|IX`VZs|B6^|nWvQQL6zXUBa!g2^-`1*tQLFQ zFz9n{YHW|zm?0m;WdO}1-={FyAnHA?K!U-?>&m+--GI6=^Ttj@3h>RdV^K|BS*m^Vl0gtTA_RxZa}~A3txIXz;Qn>)EMsFt1Ry-~q z=9@KYOlG6+sA`dr64t6ZrweLq?;fBQLbfcA+|Ti>XA z3CqOs?PHsl&9m!I*C*FtGg}BA%MB6w-{XBPDC?S1!6dN?$nZ3;@%Estl{W@*Yz3MD z7&_+emY1mtBPFM;&Q|LkWvOoN%ZirIM@S9i=?P!O3+!z)SL$=^GlmBI<-G1`(RYnI zj%RprA8^@*hsELiaBfUwQPb_6=*=z>dP)~laQ z#MS!*)uXOZyB}uJKubRu)R`aUOB@0ZPY=t*otU(le<=*AK2NgAtS1U(x`E9{9^ail zQXI{?{N7(DBYA1g-_U72@=p8UyHE=7jY0fA=Eyoo0GX8{2ixLnvK4Gg;`QlVcTJ!6 zG0ro!L#_ku<&wp^xbrfX*^`de=l&Rc?dQhViSc`$uf@N-5YM}5TD5d|gO=h(keb#(q1l;NhM`pF#d0FJ`u)b?lOk)13^EMetmJ zG@%k?fgm65eeUb&Upfc?UCdkY%w=z zrLn0_T@D;Wn$hr*=v;?8e5BCeG~aDnmXOrcnf4820_4<_q9c-4S9IRXFQPiQxt!fJ z2iNKk$HEi|4>v>=nnLiVL}ioWhax_fR>fa#I3e55qd5nMe`(9}hS}ko4sG}wGH|Z? zh%bSV^HuDaKMFbOrkJC9^XfPzx%r40`~Sj^$cwT=edGxl@N4VGKQt_kpyX-xTMj`a zjwr0kD`_d=)n;8~XM$c+j;C%5zRm`ziyHI8E$ppjXTtFY$@A&O-6n^o!i9RscP6zH z*K2*g=3M7+cs9$yTYV(Xtil;LWvRGyT4@3hMMvg1ESp(!i0%Ab>g&?+KHAlp<)1a) zd01X=XLQM4;DEN@TpHtSQas-PaCS1XNR6Tp?sda0Y*qzLA`d=8TwSHDYi7hV*jzln z6uso-%RF>;RU8qqZmE7gM5V$Fpw2qyIn|*2%6{=Y4+sdv$fKbxg1XK zh8zWGQA?Mz>JuIU+_2fqI@O&Y^>))Buvr>P*Mhtm=f5f(b0&#lk=z+O=5sHMX~QMX zeAu~Rx1ZsaMjyIVIlXaBX)`Co(YvaGVk?YM`Eq3uh^;gfd)Impdz}A5Yp+TWJRh6< zX?9xAP{33u$T9val6djPZ@;6>prn3Wj%96R&l=M9OFuw^W2AwvC)x zCr7~HvUo5dSd-T+=(0Ez1xG#V-mod=%bX%A03m$!9r9AJJZ0Jq|*0D9v zGI@!;sT51dyfHJo9%`up_|2Mwuc5-E(ON8D<9b**brrO5G!+JRBL(!jzN(@Q7j-2jD0|Q zVCcluH)7=bjNAjUNq*p->aZv@h8|%rN_Mz&T1AYW8Z-NtbWIP)cunkuEEg;^G4}UK zS+3!+klfT}Vxm<(##~WhvUFUsDMtt~oyo#OXft(W$bw}vQ%o^168TeNC1VBNn*{zi zPU3jdXfgrbYJyba&qdk%al)~J3SXqg$RIybNu`K^T&Q)4gD_xCB6re-Jzj@STI6() zZyg)9)2A%(w;eCB`t}W%lkw^cA;t5qOsD?fiWd~9nocDaC!j_DQGie+i0#bblz`Ye zE_~{&pRAT1shX2{`{{evjrteDZm~38!5AG&LSt7GLEA$(8Qu7m1xM9@bbyGu;`y#jXY}4WqP)2Vs9B1GQ z*Cdb%u#v|YKJ`7PfLm zJ40@g5FWRe)&3SPS^IR{_Ue{a%ve}%)5s@D1PIWK(~R!WKL7OCen3V~K@>zB zlVp%+d{?{I5LxKbaJN zI!R&9vAJ-~ohK{BVZ)_1p zzW{wl-D_}_K@v#$S6Gqb)EZvMId<^8WCkp4Xf$oA&{ zjm-jNek07hpV{6!`krEDVtVVt`~8RcZ5q6%*qB)VXqo95{$)Xr&&0xt4+JpcGtmS7 z1JWXKQL^ z;6Q6*Z$k0cg&^?uvH?dk8*9OL$`QFBI|Drf6Fmz(0Kf!bU}RCJ_@}@BLsj~lJtUy_ zc2(Gz{Au%_*8ht;^nVyfuJ%U8P;aCmCMf!UZ}`kWAR`dp82>LFfSHNmZDij!eCvPd z=vf&5FFFRm+ZV-u+R@Vk-p2Vqbu0k3|JDJ3Y{384(X#-V{@X7A_-6TE^}Xp>-$MN# zJ^(=Gw>k4aI@Y%k{ihu>6WjlkeJlDlPyV)ZwAXvP?(P4)BSOK<-RM15?{_$P8=JTB n{)^@GE{j_>5)vU9VW|HPZvI;z diff --git a/examples/dev/backward_trc_final.pdf b/examples/dev/backward_trc_final.pdf deleted file mode 100644 index 3932a6093d9be465465aa05099c82817366d1a44..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14194 zcma)j1ymeM^DmmOZ;^liR)CeBDUyHyfJMg8(#X~r z!1=_IM*;u9 ze#W^HF2;pQ1&v}5>1bbzDHHRHayM;GKiuceZKQz0sbEGj3yqJov%DnJ@|XO^YJq}S z?GoXgjnS|4waKHa`n!4pliNWuhxBHyMk^2XgQFGWTaw2I{p*wcRJGxwpIdgyUDi4^U;?>%iNq&x>93irr+2D9X<}e*!lOUayU49c2e>7!&;YcD*R}^Gd z4_|d2Fxq?gq5btDYtpPh-}=#)i_>Md#zv2txpE?XO#}-fe>JAYPxCWcm1c$0tCK~& zn^jgvL$9=3UCxJ3HkV0`P9BQx+jH;ZUlL#HkQ@~+K5koYcRXOGPd-=$@V7qBw%^=$ z9$DjGTIrBU-|h|nJZ^3Bc)TdTx<_|CIw$`b{_^VK)RX1Gx!mxP<7VL2A)vAR@d4lU z)@EzNQ%-YHmgo+k2zN_TPG zsBh%~5||~2orF6XTj;gX!T?ZZKUW*{x_3AB4uX>Jl!kJ(wMk!dwpp>Pqe*t@p_86I z7QfQUdvzb~x^|x|I@j9GA-B-7_5Mw@^rEGzp!@CmjQlb2VYN=7j#j~BHc?|-)Tcbt zBc2`2(+cN_s+xr}Pn!5oY0h(Qb*^pa?iKKZ)fUH==cs9-M8pE!%$jffLCfgDiAzbDP`6#Vx?TQ?eC!af`RdBKcBn z%EFRzsAANj@5mtQ6Ds>|=K-;Oy$%0mDZYB|Zu1=^q_A-$%e{xf>$}z5!9(4*`MT6` zZWpWFC^BomWKcRDkgK5sN=12!9pFy~sMr`(pA3^pQX)sp>ZE zDy56q)Jf!3`*vDa_}b{2#06O(StoA9fGUI0%WVI7^4R$Xi-^4+N|N&l6|!vfx#rbQ zoc$K1S;BZ839uwfS{YTpaEjM2w680}P3xaV@!}Kc!U{^+6w4Dg{HT{3Dz;mHN24h&htipW z;*g>+bXJ0HE}w1>6bww{5g$=yhHgS12T*C}sO!T|=dX4t%Y!Vv`QP>h_7>Uny%xd} zQ&{w9JYa?rNM_z*l`*=LP4agkVfX(^$j0E+K)8@KWv=IRei68#!pPzK*yv1eIseYq z@k$s%#pot#_;&fT(m1Q9C;(0E!=IewpU0Wp+ zUAg8-+yKP3`fN{v(kybsCVxU=1r!Hn@SjdG9*jv$Yi-M|pb=QT(AE1U!@j8i<*&yF ze{#N;N`>;Gua6Yav9%3RgY|_{B+X!8W60$(6zUHVUMph*C?QZWn5sxHY=uyg$G;#u zO~N+%Ui=HI!?n1G#U0;G+YQ$jG@`uDh_AnSAb3o_fsFvAiD*xc!0!Oxft3Pv`BAF# zpb6_&``Z&FYB9A;VX5|0d|T<0w|H~*j`kXe1sbV}{w(xHXT*~fm}RXoI$`8?vrc_= zos8yo;byLiI#Al-p7^hkC_i|dEncGILxairAsKLeMbSL#D8?BGluo@%ft)sP0a)ET zydGCVMjr{Qp>l6wG_Q&%X|WR2vp|xhnMg16^p@p_iU8d*sS{l^(ZWbobGJ z#s;?p_$;+@9$d4NHF&1JhG2QU6Y*(>RE+Fq!sN(PZdIR>mu{>E*PwD8xca?sfQpx1 z7DIvIFJ{YHM6y*9I5M*kkQb(Sakuqe3v&UE(a7@hi^FH8nVL4vVmwzkW2r2>+8dl| z3?uU|Mydf32ka?Rmclv2f*ZDPj?R{|dS*mnp;0$DEZj2 z=rGu@av!Cc2;eUECuLUl(WEo+CYjF^9a+zzO)x`6CGRx7A}san-cgjDU8aGUgk(Dn zuQw6ewRCCd(i(zJ>2(7>;FG{6%rKEU%r{JPtodvZ?%cwC7TxMsp5uHE;`W$+!+78? zhQ~fxtpBJWin=K3%j1A(??p97hAabJyEz&6vid682%*vD4Fes6kba&mg*TJdFi{hw z9kxB9>XcUvBvK?xtP+VFMp8FR8)^BSOqiIS^`tY5<2I+=TfLLE+(JbE16S1L?KA>o zKPxc=V|l9tt@2ns;O^Hc#-3q$7Je8$U_pGjzt9JU&%pdSVsN~9r&b|e$$pw!*yYI} zTj3F3_>Fx{d}yQCVGh2m%bTWg+f0kd^qfAoX7k{=#KVw;ArdG#I{-~p8VOC`#vDWF zEP3?TJ#GssHU5weOr^2n7n!IcKk>Z6Z=m|rtpBqtn`^B!*^CMvmeV{q+ zoUgjCagziQU;zD7g-aJjwq>gEh#)jE?5jR%N$o&z!q?eCR+3^Jm@!6L;`f82eOL)p z>NUarL|+;;%p$bIlQ&~%{Tca4Q`Mp4TP!hch-FUs2@g1;d?!{#oa}I!uNB(5IXFD2 z!g;0QdPA_@Nrwgb#inIT(elz%)JHoty{aaSH6N##&5?JOb*0{tXg8%7lro8J$EYID zu{9vQ;{q*Io|eS8Zw*u$qjG5b&;MF$K6I=oHS|Oxpf3V;9lOIBE502Ul<=9Qkm0^v zLw69#a%cvAKbt*iWXnC`4ZkK78#lJ<`EX5DFO{^@jUqC$hn2x>$Z%kbkE=Eh8gOjM zwR=4#;yY5)HQPOrU!=tD=ZiwFtY*AgE$K_O6^ibPV9n%9rM*Seo8p259w0Igms#3d z?RFXBg^&ixl*S%NP&1c?0bkf)O5#YyuJRE9*-@){<5x#^IKJ5%%Jo%Weyc<+bonUi zj;8m0*sH$cF!cqU4m5`#B#)MuJ74Bdl-sJ#DmPf@V!g>b6D&~7ljE)iX zUBhNN2ga8o9+Ok0IWn*=?n}{DeNOx4q@2etnR13JA4%o2!*gj?Gwc;M{PgiDSA`|x zyhbF;@%0MtP_*umFf;zbTih}+91GY(7s zmyJ}M(uxSeWl`lgCT}Uo_mr^kDBS!-zu8Co8h#?}J*M~cUV;G)Eg>=qmTVT~w1pVy z7nfvTPlaO$BWz%zzYliavq}tGH%rGQg&Z*i69=+1m}r5wi!-!?!g|=g!641uiETD} zAw^HG3Ckni+rz{rl{r_IMC}d5ZmK{BA{AN!TuqJR3^zH#bbyVd?Fjg|jYKJe6HpfEDC)8)zhhk|cjev-DQY5HAyI z6@fhvKZvQ0G4}QHsVmfdlw4>^U*H=NTOg$Govkh^@Ll@#5=(z$=+w9I9r|2$$xx;6 z&?(1yUP&fj+TvW))R}M71g7;1)piU`GunBaenK%@oT#USoAc!GWWJJj@663X zUn{YREgCQ;0z+$MHkI^{r=?R>)7X{gjW-RKL+3;g4-=J2u{Se!WqUlaPd2AQ5NpXW@DsRBP!icPm^Q$pS}V@Hdz2}nw0A{We_0nfgN;>)hf zSR~I|t-m+VMS4n4Hw zt+wlr&1^LN%=PG9tx14*GXsXyyif?zJ$-DvYGE&FXQ)0XN!h$M*P5`J`f?Z%r5EZ3 zYtzJAftlLfBz^vVO?Rj6{ivpzQ||^xmaclL*~}ilfoX7*^6A@QtSNFOsUjzrVH#nF zi0UMm^R*grfnFs;w35?^2!55^L1Xg66XmSgA}dk^hUr?3&`*Mhha~W{YR1kS?;CN} z0@1Ra8!ESUV#Q5J&L?m@kY%(@fe__40*$FTdn+YVmJTQ@2D$Fhn8xApPan7JBR=%& z3;7Vl+I0F&$Tc|7h2ES|iZOSkm2zxFTjav^hCx=_imzsAva}LRpcu^cWHOy_ zfT3_0fdG?{VlC*`cU&a)I*yC{ZYfLAo}iH26zUIAZuDc;pHVxSn^@qeWNnT~mvS0Y zA>uFJkS>Fp7ZGoMoWq8F0r@-dhf42;9K<#%{j^VOpaB0$s|}K;iYigee8?d~E+N(I z(wCdT%o$T3EUoR2<1<5u$I`yjM3Z93hj3)BT^BOszJgLdpvoOVsho0+I!?R(Y<4<| z^@Wa}7bD_LDxHWH{pn=|o>vKA5DMw*$7;MF=kXPFX41G)^hZhXY#G~$H<(JPHYr34z?NEG$T94~hJLyyvL}44(|FJR!OEeLt9i+M=#jIz1YcmYzZjRC0&Cx**=*4L3MI=K@gTJCDPye3dF3-W1=jT(n z1qfv2coujjIsOxt`5lN6bF`IEuzd=`00ab{8BYfbi1mp8u!sr+*Z?g0pr`v&Xy&Ql z&z(iw%F^~9(qJZ5CRPB~-=Uo+p{L(p%Gm#-48Wr3WNipwkp&qU0;mCOzm>~^AcjwE zvHe|_<%6Mt2}s1s5uowJVg<1A@^AonIJf{>zr{tL6hFz^0G>ljf8t0I5GyHgK^bGEQi^p#0n3Va2v%;=*Ko zM?k=&+@tHvztS}Tr(lvf>ictSITFy&j}-2~joE-(PM186?prjvKQ>=#oBdq5DUtAD z6_I>99hlMfT@0PM0mqb`a^nl2YbdEvPc4J%_L|xEE*Ybvaf>Ojb*E_0!%Tv14!vc8 zuhn+CsW(pXz4Zf($bBI`j-YCSYMEugcB{9G4}Ke}M%$T?rpJp}DP%mRqO0?lYw-xw zmDWqD>Rgw?UB1Mt-v~>S@dZ^Ye6>DRVIt96zfwJs9go^9s{gtA4T}5*N^eY{CM{Ls z(x>{xhSJbMa-a+LDKH%*2fpinq}SQi@Jb% z#U1_i(9wByXWYaH;un=!6dWoB8ph1W7$du^kj#EcJ3CtwH?5<+2)?SQEG{)Yk)V** z^WTFc<7Dj>ZwJknCdbMp%vK6Kx=$RmROQRUIu(u;1yDc9X-wfy;ZMgk)I+oM7}g7G zWy}=tPr+vpihn99W8+MLRg-Fo^XWWj~twd>|H0FUr6M4-lHS|13zI%LPNi)>_F95BK2{idb|@H$o;>@VG#G;G-^))Z3xsdN~y z5`hIp&FdY835o>@wVNGy$s3mSS@r`*-;1Rbj0mXq>#P@C5)vI1rF(gF*x)CsCi1db zq{=?#T|3Huu~^8)yRE*Zso}|8;$ByZ@yEg7M$K6=CT2h<g7IflzI&FthMm3uVD|?HC-M7uR6?PfA)CDAQAw52gCaE*C3#-G zzCM*?X6@j;ecSbKNrR<^o%v)m12W&cbjOKS!hdDb$W+^4Q%8{0k1_fdxrKm9BHa) z)Q25mzB)J2_z>!2eG?g_j?fP4hG2pBA;l_9JB`LNTGvw{-)!{`ubE((Q&K~&#OG{ z)rDA7&Fae6HkN15%nn@<6=h`PtxWX0QeYt_`s?26=}W0Key2_e(^^(c@=jQD(gBfC z`Yc%OM!(GoOt!1+b1Kcmg#?l_9+@dwUQcu6h7h_D<`5vQIE-r-$<~o|cnC5-aD{#a(kH65{?;U8QjS6S-)%g5)<<9ze z99=ijS?VOzF|1X!mkcaOMbER&2q`gF#vymR$$Hs#*=>0ek8^8XE~9dg-iY6acm5;; z2i50s$&c&PO_4cdga!Xi51*rd z)STB{@w`JV+eH5$L`LJv^j8$akNMTMS0!ztXgKd2==+5k_z|V&x?SwmK2y_6#LmFD zCo?UZ4ODPo=ArBiOT$5e#aSdcFNP25i?MyOF-0wB#%su-Mle3et3^+3*|IJadSA`0 z1{0>_%JyP5Sszy^9nO~4s0i|_E`{PBeOFdcQK_;pVp9rs!NVO<_OQACc5`vi0o(29 zw7BTea27D?ZbM|!0>8oJe?EOD!guAT&Y@zqn0pOh#xc5sJ85|h|Dm)v9_F**XOxjo zExfgz@DJoh5nE&_XkD-8YFsW$wHk@--Amkq)=M?pKkuomw=}q355x|}w1Hm)*Ej{)I~cqw;at_w#js^X zDDQ2HaJ5nT^&l#!64sW2gKu_oxt4GUC+KDz#NMX+*-_1VX&+u*ig?{Z*_+YN0xL>l z;Z}@?O%GkezXFp5UmTN$@hg5cTbrrzAzZQjzB!ttVBD_GzSXa24G8rPgu0D1n(@{? z{~jOkagzK(#$?dzo?%h3ywF+HhOHrQV|t>y^;$B3*tz+rs|?zJXuEsHjB$rN<6#Fx z%^_qQEE+|q(sQ_BcG!OSHOY4&#^M6TCMPy1i!LG3ILJ{9hfB^XlaWcuL5yfluM{C> zx2f3P{V`|i+AS{y+4qHOE|rvtp$YE&;>6s*gGGI)3FVS3w9LvkKd>9AS< zOE(+6L3GyYef8atXbxj*d)61IqI43oi3aYCPMx`i@6FUNi*6;j;z_Qxw#9}HG+VDQ zP3+qTCm)Y-)9$u$V-3U2dud&kS+w~+fA8PYAx$nfbz4mdS5SD=~u~lqS1)|C8zuq=c#fpQh<1J5l+M)H8p3;7wY*cI7@R+@+LjrZ>@Nj}UL~Gg z&O3>$ZZ^dy(n&|cmr0at`b!)SFquvMlv+d9-Z%`!*H$N9!5aMzSE}-|RK?|kM~F`y zQ_+Yp+eoD7P=I3(xt-O;B(+4Epr_GjDcn@$>QO~E?U%vjhlGBnU8l97by|QBt^xv-3ac5`kx$9$P5$NI6lp?_n4Ut>}VGbIa`6ou7B$p>*tD+gAtZ8|BeZaL-g zO~~>gmQ#mn2*#EoWRo-?TTCF7%a@*w?=VyChe3Qh4P{#lxeO&0@>sSoqBmHBc}%k` z{m64pLdPo^WcK>|xTp%z7xy1zMco2=O?anw%WcGBSzjC)9m{s?a*hl13eEBp_6O|3 z@$*jK;BTLvc3+pmAI09$s5Co_#W5>PAkXPHskN2pO57U=Au}g6Co~_9g&Z`D@UQg}o(?(13GAm{ zB=!(?piGW-@1EE$+TOyf3!DlJx|&7u=s3%|k8$;CZzOCp1eUQXBjDyr7l@RWmz1mU z)h*QVaq)9Eny&A6JMZGIe_pLGnC3*)D^(;}w{KuM3nm{17;#{`d_#vCv0fy-+txSb zU#`i^D-sq|aCmzvxkPhkAQj|{xKm2d+9cw%ba(1y;ol?|?|T7rjT z#$UDM7dHtQRe@EPJHVllzT^Xb`Iq&*IQ-EQ)5B*T!-7*Zne- z9wAF}zrhuFFUV^A{fcc!+1*8JQh&n$F)g2zakqcvD%!E^r0j|0-hM6+ID7&RVfzhbaI$bRoK zF0utm{|LieeT{*;*d3U%8SKZ2&*u!AjQmr?c)f+G`txp7gKoGK_+=hag^!QVb?pky z(Qri~Y)@GK5Xo&T`nl^QjQ98nw38DodvD0K#G@Wl;24DJ2NIJ+Wu`K<3M;#7bhjxR z@be2s67gRz2>f$UxRLa2>!1^z3jIW~G1ZtNUVMHnEG64%7 zMNXOz-$r2;ox@<2;T+;DW+iY;qce{5ASa_Di7(cBnWX-KKVuT$$l6u9|2Y&0Bpo?uI%v*}RsX;_-dlU%~TG^q_`m zDq;*XzZlkqah1~YZIl^l>Q$0>`~1bm&+E&jX0ubXTR7`>SMI8lcHK631)v%^H+j_k zdrLpgM47To_Z>@WV|&23cO0AFGR3#CN_l(LAxkPIjW8CRRV`DDpdqcNC#V4U;{NZKE7tV@_8L;)wf7` zjb%ZsRym#fjkpEzS`y?2$_L1n!g&*MT20|C{_4ezhxxplbtBDB`5#GHx^)FRne6Ii z0U>)#0U_*53;_auNhTI^=P^};kUfjK0dm5RD|C_Q{Jqp{gcRRSTUi2Z9gLOoetKx4 z?H2jNotgH9l)oWCc-a}WI`$eAXAQ`qEw+tcG zbX79GA8#;tFa#nT0UHVvv$56x({+U2bD(_>h&*g%9ayXgzgpwaIZOf zKaeG0I+-tr!A4e>ZL>6Pr&YXpKiF4Hzj&{g&7J0#uw#>;fZz2_A@(ey{fLIH{QcWl zG^C(RhxqCbbbxRkqUgE6#Sy<$8Mw`>rZ^xI)x?~BSD?o$%yUSB__)jnAUPK`0sUxj z9;h^b5SM0!W)eBJQq~L|iJ_grb@08mxzzHuHg#j;jj>?|o46%PpxouueuCu)+8 z;u~n@C2#Vt_xJnZ_L^k}HS6MmD>L3HCc0L;;yRHlEknv1%qL$824&@>G1chsvR`6_ zw8$2EijZt?4xvHMC!DF)P8m7Z4cHrT!XCg4h8xHJCCfcWTjL?G`@OH|?F9+%3_HK`)$uwpw-888WS?fd5!A9Vv>*Xh77cYGJrGG;lMbP$u|ik;eMwvUdp|$l zRW-DJJc=a89vV9>0_f;BnOa5KCMNChflxOHDbpsz2aJ;l7k008QB+}+v62s@n7}1t z#^Vox&Nbt?qs^!n%WGeHeVZNW(md6|KC|8XwGYXzGJ~6Rp$VItM!v-Ss z1vyGDUhCJUcMF}L7G3fEviB%}bn`f2?(X49PkpL6=Y!j?33d;>+YzHqv~EGN!lY%w z{z4*_LRj|BW(XcSeFy=uZ!4YcJ1-$(>vZXe~@A zmGqTXHPva-JvtO)mhNg9%%Hwt#f^79eP2Fjs)XR*1_`mXzox*vjU+S5lo>MNDr%~J76_-(&QNUmC(5b>42> zG>hpug_@Pct20NOnK%h_dLfbOu@K{mm)!8ahQJtffb#foklM%lIDYNOn6oLmp*Hes?G+$${6C{nhf7j%@qe^~Iy ze{|owUizff==f_vq5kW2U*uL^7zsHe7bnns;rrNe?XBkIZMw&U`y{W!-RcLH)Hjb4 zxt&SjYO_V}NudR zd{teJ{v#NBG@bXfu};VBl=79tCc-$mCnYUWo^Q7uq{c9{U$>aX&^GK}tf?4J2 zN;&KvOu6~*+;#=0T}5I)o(X){h@NCsDAT7Q99F^_4rIiVi=5O+DrZo{X&qCGz{0@N z^sV9@PKk7`Ry*3MKR45;L&LWSiw*^6iOpaUePr)ziaV56mWrWLn{l5?JV+-=AgR93 zxRAaS6x>r1ctP7)9G^_4=DIb-lQ3~= z)gHsHLi+)3>ss@kQQMEKE&2GXukX(*Uoo*76tPUqW-q zE|*p1t$r&L9*y^_Jy1O!SIy{7U(w(6MR;S8_ek%#>D{yFQ$H%u&yi_i8%2RIHRz!XJoJ*N4t;p<_V{M0+%v9Yp(fKo_YrEu-8;Cw zk~tTl9blF+dx~LygX(0f^UKh8z2m}%DQTkT$IF5{`g6T+e`1eI*eCsOo@+xu}0g67F5xZph#b!!binw6M2!kt8K*XPdC8#|>sYk6Z&mbd0237r^V z+qn3e z#3xa+uSz@c_V=Ot{?5S2n{n$4L11`+SSMlMj0wR9gO#_ZbJ516msmeoL;QDsDz-O& zvGzsl#{j0ddeU@oW7oK6jJl2O>4ub;3#cRgG+WmSm_ay1GY+VBU%i5<9v+){@=9sy z{?@MmY7YZX$Fu3plnA7U9MhaP$5V1w_WuFu^xO z#-q@vH(0A0CU8UV@v4!GTVWTxFAwl_4~a1#3C=0nLP#Ya?gkDhHh?6^sf)=Cu_VFNF zHLQXPlG@18wXP6nR9+JbTA7V2Ml@#uY3wsswt*REqoED^Qq-qfx&v^8nGO4VwOnp8 zR$uN6vBdzYeU{8j=h!}SzrNar9KKU_S}{lDUWbT!NmNY^mKL!DskVL&xL9HTz%53OC4OV^}K2AhEA2BB+rx zYW-hp1B02F$-jt4PGE5`#Y&oRYUt5Cm7s{?>ChtMr(_^2RqDD6X<2)}&IfOgzRx{9 z?RWJ{jh6}jk(0TfT^=!H7sboqBufon^>Pd*c8K&XsjO6s^ile|BMnh-fdQB_@oQ9i zk--fjq)a*&FUo~v!$^*OGmj#*(*I78Dq4b3zweWbfgX8cMBgOhxq)lUcQGpG_|%Z9 zNPa@EWkQLQo&^(aL_&V?O#MTHA!F=Ul#%1Xkzrhz-I6#Kq0(oGK0-1P*vCk^Ga2qE3+IN zoA9F}8-+!2lEIdAa4g}i)1h`MslGOFM_@!m- zyN{!mJCC(cj}X|hi*%eH4s!p_PCw^L|H@8tbMUhNoiOJ8cUG0{ITQTbgCNZ8{el=Od95zt_Pb4DpLkp?SyBy{yUeeL^g*3f7kL3W+ zOt%Zkl9$mEdb@;Bc>Ip4gDZ)*y3zu=WlPD)O6r0XF8#H&;Hc|DfV<*{tS{l!PHF2( zENORhzkDXt@vX*fNIK+LTI^reqs5!ttYQbK#hW%xr`;+Z<)msaV((MXvRn!sqMzC& zPCQTum0|BQl&M5)h>va;*pVz+8%)P+Ds#KFDv{i^3N8S0_Lm%pl|-@BENvYD3kw5ex<{z z+e)r?i*F44F4+ql%_%C%M07y(%M7$rtwuYdxAA(W&58=&&1_eLJU5C!gjuRoOegs# z+)c?2Evun8iWG`4Y|x7iHOUb@>uvF)l-GC41nvXw#(?p1|B;#ol^&6!K%Yp1jXd#dr5G3xR#4|91IrNB-q=4ql$WUC#EL{{K6Z{wII` zj4*g-Xrm7?v9`5>Jg4J-gAP7|ES?Yqic(S%A0!yx8JdEW?Gy|lCgA@Fh=VLl%$)#q zf64&A3+exqnuCmR-z)`NLknd9&!76w zgyi4E^AMj<0*35>XCT2pfd$V>ksuj(5F^`z^4&Bzk9fzCc*PRULenJE*C5NZ~5msKyDrYCmRO< z$jbJ+jFambhw|JO@XwrinwwTimL|`enE=m?{nM-ep5)I+7@OZ#d~ag#Y{O@;jQoF~ zE_Tnrn1A3+*#3#7k%d?p*y%rQprX?UnLw-nY|Lz2%&Y)9V_REm8$OmN(9RRr2m-P; zHqp0Xwt^VZ|5*m2Pur_(O{^^6Jp+E|-tn=svIAMUS=rctZ0sD|>h%BA_kV#xe`9w< zK~HPK%IJ5WfA{_`iDee&z`0kHgw z#|q?P|2L15?dc=&-*T+1Y)?M^ZypfH`#*VX9Bltfdpu98{J-ly4VvR=nE%z@6OZ%X zay%SQ=KmjZJe>d5%fcIY diff --git a/examples/dev/backward_trc_fusion.pdf b/examples/dev/backward_trc_fusion.pdf deleted file mode 100644 index 31a387d1c8b03317011050e8fa7d397e8e4463f8..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 14244 zcma)j1y~$Q({9it1b1f%?z^}rxCeK4cPF?*aCZpq?(XguJh($}mkY@`=brEX{`)+4 zhn<Y070dYeRwUyuew%LGGGb5wc{!w3KZEOks^I5`1y;`$bbc18fE7nTeR z005v9G&Zx-w|TjmY1`@Z>+4$T>BDey!`Rx{=xdw9IHioJN`&LqA+=nWC2S09FFsI5 z1*oA5gJbp+R#?;`v7)d!R2H#@Z2_LSf1!O!s;X+ihDz>7B$lX9HiN(2pLxr6rFF71 zx}JWxW#z8+RDo%J-(`uA?xqXxal4~~@XWiob94V=Yp4vtfGOY;-7K!VWBF)O4JKQ+ ztJ4y(i0S()t&(^qVNPN;C6rD|?$h&IsZ7->3RIiHqWEgN98%lvb3|SZ32Hu?1anST zZc9rTgZbQv>TQuU1kLKQ>kUo#YfE$-1>KMj-)XuKzipgEnXv_`=b#2iTH?o#wYD%g zIz6t9)D6W8**1#xm4%T}_0dT(NIZ9)STa*F{?A8i(}9~XhsyYM#rUD>{N7TwyrG<=x1;tY^CfB_ zX6EKDTG1`LmsU3x&VrZj74vxm?M>5goIwQ8VIy;}OTBW9ev3xyenHp%CJ}u5+Y9o> zTW;efTSxZaxc6=MOs2#w{UbvsR2y|HHzJIYj=A(-lo%(ufHC$LUTJ;nt(K&1Zba)nHx*DNxthkP|I8;CGVZl;b%ucmptRhxA)I1tzBL&SF5P*1y)HZSF$_WS!yl_RHs4r!mS@oZats1+Q^=W z?x~(59~ho=^CvwLD%m*zOx^GWsFRenJcT}Wyo0^o%Rw?ADD-yt;6*thAB!i%x2%ng zyoMVT!aGMTcjaRy`7cJiOY6jF15-Lh^Ss@GJ4kR0{P+D6dH5I%O69$4Pj(MxeTa{PSBE2>kBF>A^CH|a-#aB@Zd5}uyxGDLDLP1 zZo7Vy&Msw@eChIvsfKPGTFnH+DEjGoXL)a6ue!gHX_k>XUE+a=$OF6xM@KJ~pyQ09 zIh-y;(x6uh9ZxgGpnC;4)+gWe5zuEz8QKLJ)h{KMC`mQ*=pkrw=gsJu1;^L$rv~=* zz87p)DDHq>A7Z!@#f6oQ=Y_v%RDf!Q{0Sbn@ZLBuRbIyn8C-)Td-qUWn6Q688wGg> zp>#RFcpx^cahgJZH_EEFUOH|cZ&B>lO1z+GRJIjQO}ogYxm$1xt^M+y6_1XRTy*-~ z#bXz3Ga`}qN6a^#M!IlP6(P#E{kmIR>`oZ%Z~)k`%dYd7dOZLf^)+nt98%WcI4TLxp3Yq+x0)n+go&EXA5QgJ;mv3s%Hxw4aXBoI zz3nMulr--bi_l7PS@ec*ErfFo+W2M?%%%9eRckSLS8}IN>2kj^NLBWF1V^uON3b+a zrIIVHL!%dI1>sS@33i4i2jJ`EQ1>~0qN;U2Z$+NCmMZUa^e6gN6)8F6)J_1z*C!fA zYufvWju0Z#FOY$WP)M31=ve?JNYQ(^~?Xhx{gai7R!le{ZjK-lom+DLI%*Mi} zLc}<8Yzfn}&7*hLcmO(Q3MgS`D9>sC#Q|b6AJznGcIKWbsc2$#D#I3bx>M^=>_bnV z3UQo!8}$leSSR#&nkA^E#MnDbsi^7!PxQvIN!xUY*3M}f04DP$(<*3^2Xs8hPY<0( z=F1bW&c5M0qxq4XPs_h-8)&f!-FfZY_c^$GuSJ%wFO-nG?D81H+PQNl0(-i?3ND>+ z$X7B-tS>LGD-~=38#~2+x}90h>pLi)x4K0_w+IxS%QSXw5Lj6E2o$XGH{*s9j7-^_ zEF=&5F6>Q**e$X|+4$7@*-o|wl|8Fev1@_IvOqT)DBY+L5Z@PC?5vo3lJa9Sn4ZjY9cULdATaJSl}#^yle^ojHz|a*R@9q zhJ)lC)HTUodx{|waAi9q=AS*er7Yha$A#me!r=2&pINFZYHEgSkz%(j+o}TPm9mZ9 z5+Q&4m}<(Uu-fCaDr;48Njfb9O~h<$pO-}GqN{Qga!d0r(VZP<+ii9Rzg=-qzwR?S zlI*VctPiaXJawGa3BE3w^rUtVcV*C%BK98VjGng_^RefRR1v)dY)da1oIRv&tb+fB z;U-4UL9XmzIh6g7dI{5R{*HuAER#PaG`^2YK?S*HU4Fep@1^$w+ z6!&`--lDs98O%fwhTm7MiS0~G>si5Plz|8xY;bH)~`2;lp-P{MJ`Ar zUa63RLy$M$;SKPv-hS(atwUcelkshfJ78K?303{rw3GE^t{cAa;LbNDk(Hr)8ZV4+^ECM^e| z(FtIl-bWS%V#{?-*<-^I(=^76D@sk34c*?Krc39K`SpzwmT&%SwSLQ zvkdgU5+UgQa`LzU<|VvrOXdmv!2M%FXa%+Cul4709l{W!3+vUv-=Ub|PPrxpx)I?( zib>aE%yz!RU;y~ojr%@}Z8v;Qy5$m`Q3O4y zjuVqU_Gtet*o{)cG--PTpzSa&rmp{d=A}dY^bCr!RgvxHy~#M;$(r33PNU9XfR(s8 z%U)}zMDu>~H5c4OTJ7OmYEo(c!lv88vnYJWcI4Ti!(wogN~jlxoOrE zZNe)|!sLg9_z47oz=LmiN>{rOw6pqZ7E|>IuFTK9Tv5ae9xCp;d&t zDp>fspPvb3+FP}>;$j!)EejGY-d)~_c!6PA2?xI11X&@zhdG-(ivBpauk;m)MzVZ= z05xsAkgVCYCN5`!rB~yvTVY=^88l&7=n0O3t@S!FDG76Zf95HQ6PZ9Vt@O+_jb#~2*f!&JU&XBsv?vHX z>#DCQJRe_U2#S^dz-SE7Q?R7FnBjeR>_np+&pT=vBXxO4V}jx6=^0 zgmy6;2nD=jlKwHU2nM!0LysXN7!Stt*5U&-`b~?9g#ig>YH;QLMn?y>ZRB4tCl2p8vf;)KHTrSTaRzM@9>f*HL4 z%62B`8BvR`%bDv5VzSAGE(}gN3CZC&n`q-4>OXKvC>8bbV|5vnEW@ z3ZV8UX0rkzleSB;18*qtWs38fn;{x-Mo18yd`%76J|$-7VM&|zQr}Ktua>C37xtQ@ z?=@|?7ieC*Dftjukh$4mNw^dA)0+aP&F{o_)r8Nmo|dT+yHKH)Cw6_|z{L2wWR9`K z)I*?m+tRRgI$#N5kivbPwf1Hgx=(uco;Y-a{6MC)-}<{!OWgOCkdp4|Guo40>wZ)f z0%Q5xJoa~INK+t>rFXQkgIt^S-KuN3X^N+$__}wvMYsvUhr_hnNga_TqO19!ODZN2 zpMbY`!hHPx7GKDmN_xq?lYF+At<|=@M(oq7YSBjSBPq}Pd+mn2m~6HxbJ_GratK3p zXG)7i^aG3 z#GF+Et2e9z8iD8Dj{J4mu_3z$GaO-D8OSbN_#%=m43K@tD)SpE(OXW34#Qujj`ms2 z*TZ2V`hd#@u;zYX?z48u$%WTu{R`rC=5t(Zi1?*BNowbEiDW8-W#3uNCuPCybSk>` zkf#?T6;&XeMO>=La_*9gXty2Ox^G=r6JQQ1WK4gcj@sKNEb?t030qMp=^Ht)?NNM= z&sImpYShWK?vabQjz3#VkWoy{dq4ujTTjt!J&KVY*L>F4yCa)%ekLloTRz5`vct|= zd9XGCcZtH6uJuRUznm}e5`yAWlWTQig*>(1cNQLo2DfA<6kDs{$)uC0P)KjFAM*MX zIdJdCdnG1K$8mDwbrUqx2h?fjMhckIm&h+qbdeOSaLNhFsASZ{HLa16NVcX(LYCmD z#;Jw$=ev;>^ArT7>Ws@awoqjj%UG#kF8;WSff_nnv{k81OH_)V*+)iJSt%$`{O)C>tfAY2L2oZL@kz0b9%yv1UKKPlFUm8e&?zw@<~3 zT zPbcmWMVw&o*4!uu#N)5QRaC`zJ3ln5pc`JM3!^+K&`b?_*SP2m4oA$8!kLs^9LH0H zV(~bdYPZ;pia4pO`W%z$K5D`=9xzEC$bUpMd{`n1EhnF#ozJ=aj`IeWv*G$c8dgk@|lAvK^WMX&~cqJMBN#y)a-3U0@ z3CY^Mq-_A4oUe?R51lsfg#gg;@c=*oI$iCT^Gh=4rQy$+PSDcA?qAXdG(Z|4fcfua z(2LN^-(Sk;|4{~@lXJGx2hd4r8|ni{0HEK>rL=AIU&aFc-Iq>6U(Z;Z*U|}~`oaPN z=z$<0fR%w6p#EE&??v&8ye;508TBXgBxGZ0Z}qRq{893o`?oc+FRE>{Eo`k`mFYVF z%@zUJ+1Ts*~u0@xC(uvQNlBxna1)$|Br+i>p#=_ z-x6Z4p8b4SxfKUArbA9cjez`=^3OT5jg4k zGi*Cyl2VKo?n92-8nm7-dmKNus`h+oxz;fK_4BSo$Op(P{COrYqkTyLfwlqFgq~>g z8=!kQsYypUgZciJ*7qS9v9oELCb4a|Xy3zBhW5(zqP9E7u;$d&_6^zpwkrHT)th8e@9YjwXCGVe9hDCi?h~! zrBvaoE?I>HLt%xdbS5uPEn?>Dtl; z_kLks!j0qB--(83-)?PY;HLCt9_J9Z-hZfB(Q3J>3#gad)!hgkn^$o`ON_vQuFU$# zps1&6M0Vxg=BTbDQx?`Gdm_gPCn>EujWLZe6W34=M%Sxf&!e6(Tfi|5m4PKFSyTpMN`X+8 z`qZOS4k%q3B;+9tk>B692=AO*{~uG5vVcZ2D0Air%=RI^LBLuxbG@y zo6*rGd$fP37=CxGFxRdJ&}h-Wb!%3&qyM-ro9a)jNr{}uSWwim(W##xS0G!v)rp?G zX;Gi$FnGLFEFx=&NqkUeweU3|(Me9Uk5v-{HCZ*8mrW;9_9gGuN#>jRLN@w+^*w0~ zYwj}3hJ1`aDk2M9&ax2>B?1<^tn#d;FjG#FEKR}F^k)+wlXX5w5z_$M0N-GsW}6iD z9N`1(hiy`LI0;131kxljrQUloNK!~L;!#wCB!eVF6$5AkXhTwT95Fu`*@A{{nRPn2 z7-%sR4F5u3O4-LZ#kXNHak+WGFNSI-Q03(Q+vFk`qiB8 zp5?A;ugeM8TV-Iv5sQbz;=UWU&0850RvbGlQdX3gVbkgFS4^hW2;M)i+h|N0D%J1G zCm}Fm_y8wMD(%OVWCDLGQVn{JqkEZnu9cVaKQ&z1HJ4V7g&Y!-WP{y)ocj;9K zcRvWjCVpB}`ljBa5c7<#Fr_u4U>m2J6oj!3GmOoj7b7a$z(_cPTI=6~Y7Ct=xFnW0 z{0VTM1*J4TTOq0!{GHHiee$-X33OI%36 zyfsN3G&5tdM z_GkU~4hijE`qegigh2DPr+N15ns8e_o^AnZ38xwKlJ=hniQ*TZnEUUf5A+!~B zJnIZ$5_82Jb9b7pR_s>XR;JLIwl}0RDu*ZxIeggW&oWTqe4dy6m?iIu%xp&KFg`Us zdk!*U#K@PJD%`FR*=9h|^{*(_-@Qq!aNWASb{R#-U>F!PV{@0g>{QM+);$anSG_Sg zjiOwgUu(xJY3GAS<#(hQ;GyJrFFM!r)j`>sghD8G7Ti6VX2o={f&nQHc6US+(k57t zPUt0ReOOaY`{$H{qd(WBuaV9ebX+qe~Klg9sY z=7E>}#!rPo(R4BQ7OIS4Y!_|H;ucDxv^XBzTHhLWRI-(=whQWs&@f_~AO*ggdamZ{ zb*Xw2j)QxNd(cLyT8H(%;znzO>+N9dP<&hBw<>uwp_!kBg)U!q7S;gWTtH31Y;xw( zC^|!F5+5RRGXPQ&$s4WA<@g%sAO}Z1yb`80RV_q2;Jfm^_6S#7`O_ypF2%6+6jTh; zyX*CYBS07tFln=hpuDQe_ z#`?x+kBgIYgHPu5p~gfN^6%hiYv%NX8BwLsQQuXlV~SXtDa?W_@H| zSLig@t(OM2HSv?nP2AQ}!ewPYwd7pejW2qwPw0_&t)<2W{^TdNEH)V;!EXB*Yok*VS|UHJjsvwz|TwdLFz4r1t20*^+Aku~(s_AL0>?yj3RMR^)Q z17DHGveKj6T!y!0i--9FzXonmT$T63P4owb?T(bgbt0O#2`3&D&i1!%n~q=`1bB|7 zgY$KN-xAGK-5y4NTeY%VAs3uKq12uknCveCiBKh&UCuW3LB({Mowth|mpv0Z+2m9Z z9Wa$mI91g5L1XhM6hlJ==O^;m5@f03+fv1^5*{Hwbu>kzJfP7?zTp6;UP61zt0@wp zG%ioWu~NwC%C+N)9kXA11j6vJOqoT|)BX@$n?_4+IKzCIC&R{x zioSo-Kz~zG3M~;Gv&cuw%@PSgB1=ag^A5QPa*wpa#Fov<5wdfqQV8O~HR? z1VufZI?d?k4=e3B0$JZ2o1Dsa?Q>2FbP7%L6AlLL!!h#CKVj^gpZDCBLLJ9GkSew~ zgo^Mmg5|u+E2_<2VmeuM&H)I$@{AdiKuj7VYlAH#pBZ zhg?k~Sv6gx+{c;wG&U1M7hWcP&=U!A zdB0nV+14!Wrm3lzeW#R73qFddDTo97mN4vSAJ^YGPNQdkVl^0}6b5b${eH^kWZ?s9 zy6WesoY{`E{3u;)JSco0;*CJ0IWT<}yUGUFHFd6|aij0*GK*W7R7#9h*Sp25JZY`e zog#2Cj@T)#!GwNeA*wivI6qLyHW&Fa%xdG92I@}rMMnwJ-0#ph#khb*OE;il1^2J& zQ@Wda@6+<}srCkb-b8;j{9&H&T>9f}G(vkA=k9xwNR|b{L3Ap8e)YvS2{u_GtBJ70 z0%oprxD^IHJrBoz@CEAUvZ>OShsr=#uH#V|e^KF1WFn-X*!;~HwJ6myJr}fWEkEhN zPD;kHf_l`tb?eUYkSm)vZ!u|~Hhpy1ohNe}5x$NZueh9a z42CqA72Z8clBJvNeRv)eEyn)g6^2;j0e>2UA|bWZZB%5(D7qSkv__4Hw%8MxvK8#d zgu(6tk^JEouhB*;O||u2RD)Kyh{4-DmsK+A}i4eVE1H-uYZ3vgHQ{dhc zXJF3G5cGW^w?fZ4G=bwb#H%ngLY0{cB#JmfO*L~!(@g#)7*oUO~@EurMzBxpoIMXmwF%Z)! zjitBhblTVsW`I0YBFk7%*u9ufI*t*4S-)ZE&X%!@oq(piR7-I@Rt>5Q=X(wTrt(RE zGhAzF=^YPg70mc-a+BSTG3#R9Qz*rhDVVzeu^4E^6bv@pj2dA!`z?0`{hK%+t>zBg zT=9`^=9*|*wR>F5u6j+o&7%H||PP_C2=f1==;zZZdEOj~0GRiQ;9M?zLa( z%E!j>N*M>GVGCj{mOfRe_lx4(f%Da@Jd@yzTsrryD99ws-e^;-+9}2<&MCSnvO`h4 zE>Z35cnJ(zaKsU$A)TPIzy=Z$3=Fjnh08kps>VnP)fFz}Hfhbr&A0`@THFr}L{B!` zvX{*`X*GrS7;9H|9%l1yR!yY8WLA^1wCW0WGwD^z0z&p_0z&AQDFZnDl8nvgE@P^& zZ1&CT1_`mg{3MS=;OHX(VSQ{oZ=(yab2O6A`{khqzgOfBd12BYQvM0|-P^99wQ*|g zIO90yIQ=*e1uX?Dh4qke=x6_S|5g8PxbH3=ZW-K4=}H9btDg{A5ji8A0GqOtKXM5i za2@a+@Ei!PaKku`SdTc5*pIl6*zgQzC^1`t&PHJT4XiA7!YMJ69%*O{OlynX^{&z` zlZ&`r>=#ZOdJc29h7?8@>UFR0%2pD8az88#qTMp|Nf0C;Ih(D7K}1%U?a(#tq*Z(p z8|p8nSbWsUW=Zo)*tJcN#pvdjjlGEII3^`87yBFw4-=H>7+)-{z4WA2v}psPX8+lKZ|Zr>%Jzy$WF!eG-ZlUCM7V zoF*HW#1FS6`Hi`h1z~vn$1EObIs@>p#=k6n43rkB?^v3LDg)qXWZq1oFitY;x%c~u zySevceg;c19!8yeXni-EjglquRnbF;K{F}3*qH>kv-l2-cG;Wo`{Uz5xPw~RVa?d4u$L6!HBWa>6X-=BAYBtExttV%^3>$qWL7 z!SS0FBQP#Z;w(~?rLTITc2-> z6M34aa+o#f(XV4zYK_*QS?kT$&dQ2#FC9WY=JT+@2wg6Q(kp7++Vmdo%k!ce_EQIs z0-GLIXQaJ-bkS+anoD-bgPLIXz=vG{l0>Uk7)vBF8kBFi0x4K!{IBvrSoP27jL%)`+gBS{x~; zKUi*)->3hZb*5qn#(fYssDt_=(tRX>VW#-7F*}E@Ja!5E86R*LGeWu4a3P~)>e-)b zAA@E5L2_^?bnI}+@uVSFIoURM3gNnFDk^Haw5(13nNRic=(6qWSpLnEg(VF{L+rz3I`{^Kr;bIElDdsMd&BUNk9X9 z!706&kQLI(cN{NR1Q;BrYC3&!qFvwSV^3Zq}Y-YWgBa zw~xT?*)q2iZ8}YDgtHja*bnU&mZ%mfSW<8~i9S9pc;r93@82#^(&meBiYXor*+R9IwwL_7lGz3%LSsT>_SVLD@WsQ@B4k6 z627wCc8rV3{qgc*-7gERv=m09}_iFQwOw)N%if*@ZuBk1*4R!*&#KmZKy>n-_xVm8Ta@OO?)_t zkoFu3AwfeFuMvwUl>g=9ER66@cV~X|Og3+>N}emLTtz;I-h(DLpWkhdYsQr~_R9t5 zl4|r6P_|5$6l+8tc_feuSvqn`GpU?X4z+DuIRY6GStiwD2S2|MkX`y7mmA$mRFua3U zER5RpGYwo3XgzmKUD6lVoR3F#o?WMIc92)T;hb{MC4arHDsS`qIqA{#xYqlo*W;!I z!TCGFr~U|UWWrw2eK(y)I$e@yS&BIVbrd6fBL_EH*i}EIP3Q%5UsiZ8UB-{_Qte$2 z7Z7^AM1eN=bMFe#M# zdJQohEW5@fxY%)RQse=1H{frm*?xb%+K4qWk*-O$bKyqE)>jNE$pAH)yQ|5Vr9hve zKhU5w)#maxl)rC+=O`tO@A71YYgc!!@AEH|2@yM_Pr+9k2ouR4kVg=pica6KSe7KZtCr&TWpTKhJl4GUO+fNJKJnz8`H~ z(fCE(@AFmGu)V+ms&4Hm)fvl00v3c&O`0AfYw(Te3Gdb!k0st9cjTDk=`lJ@dy?$I z(5gTivd_~mLOXJwPx$KbG5VNflC0(0(p($eAf$0Za9v1}&-ACO3XM_V1Aiys0YlxJAyX~sCHkw0k0kzpM8l+Y3lW0!dtjGcnHTe5+Jx0LKAkj z7iO@JkzW&}%61Pnf~h$qD1?ru2^w)v7e2A`fwQ+%CGyX^> zzIj6h@4_jnVI46`eyd5zmKzl>e`XqsD?bn!*R&gXkiA7~qeK!g8-XEz<{iuFF)p6& z7~x}ZrWW&17{QtR`ATSjK@wxqQ0~l*(7A9$G+mT*1P$E$3#mkOKX(u(nW0CsGg)x?6C+b+=+%yEF<}$#f?>+&ta*-K}oiTcDgWW6aG+UlU${3K76C^8hPj#fJhsYLZ!=B9%tif5b#2qtVgtD*-G{@hs{^FcU=Eb-Z{@q-To zy2_i<7|z*g#SD@AoxCa~Q8hWp>NrC98oD_MLIc=T%IYMOGY&E&QG+g(qNHfx$_*bs zD{7EY#a(Anf$KScs)XfMiTfN^xJeirLsv0w6H8qYK?0*%>rbu07)(n`_)Rc!5}AP} zR@j(HRfqJY1(pw8lk5XVO2!BIN-cM8bt`Y`e1n}avE1|X0aw4&c=6!XoXmsl@`z#k zC^kxGDH5ovx8vZk!}y=^r9@gqkJI^&Rrw4G^bGJ5zelAP>D|4zDHF{_kNS$cX(&y< zmG?2V(*Hq@I9iBGw_j3RPlqruqJQfBrJifdk^r$wd}>HlBnOt)3YO4W?}D+$dn^vY zOx+{BVIvegqR5Hh$S`K49${2-{JcuX>1b&)XgqrsEVfP_p)Vha{6E!2%1i>>-=6}wtxWqCdR!+oE ziWQ=N%W9k(pjTS~fegvsE&?rAh1JWOBrW-I4<0*_f08H2?yZR893vuENksP5SY|<% zhzDY^2m1Ri!s4hHZQ}X$?*q5Q6na@AvLf%eKg`PShSX${Y+gDL8G{};u>>Z2ymB^$odvu?VPqDPnY&Ecj_~# zf?+vfi`ywp*Xr=L9zNdqZVe?sIo_mcChcDCI44zO5#`_`8QnGa5yH8B;^fmu?lP1E z$}+`>O~J990(;y=E4`VREd>_0HhJ8KHm(Ig&cU)HjyxZ-O58wx?tfnw#078 z-0mxoEqx(~+m=3v%8F#k=d1Qx2N{AE$~}EdcXSOe#ZIRBFXH);X1qM7*kXs{2lPffLVtCh({bvCJI+bxMP+)VdW33H<`d1*yT1vHa?LEe?@k^%Jv zVMSnhAcCMfm4!!jtab#CQ>Y&lFx>~;jQ|tn{-ZU!#8&iY<8>IZWbyOOV^K}r3BBY9 zR%rx@DTF6(y?NGA#Sf3*3xWOk{~hw*f&44x3~a1_haB`;{{Op@{-=KbiZI~Ux7D>V zwz9Ldc`e8P1|59SHh)17$ccyuNeEH#>zimR*vsnM7#sW}AgFC_Y~~Cg|I-He-AM7L z)lA#)rG`)ck0|dSlmQJ3GaG=0fsq+N4+1g)Kp-I7Utj|fJ8d&#T^eRBl>>!1Fwgz(?QtBEhm zS)U&83MBY9u;3L*@!#(FGn~DS-ES0y+zWK!)m6W_yxO+{AK?CJaPgSFS~h+Eerv4ipCZ^7PiKJGk&X9G}f~-vVB2oFo0go z`1irU1_Zo*|9vpOKso+@zJL0#GW^dmUtlI|uQ-hV{a|K&g}nUx_@n>NnF;vX!o>D6 z1ICvt^B?Nf(3epe8DA`TefKcGEP~fNHb&OpTxKBsZ~50gj4Z4GCJ+OF5eWL-#>D)J zLwOyG@z0uhS(}#f7RImJp8&6e{X44vy~tmYFt)$FC}ynp>cdyCjLbh!7yDOW%)js^ zpnqd&q--qp>~&wZT#@T)8{1d{K(rucS|EVj$j;8nmYwbewDZC>w9&RQGS;=FwX`v$ z`12U>y==&`Gq$wge+B%I^Rv?f=^246KoE!#M9;vYLh)~X{}(9qH+F|l`{kLiH2gj1 z-?RSU*h`u0a*OS12QtR{5Ow@ z`Q=CAKjnZx&`XT}lgG%!_&<4|m*x21_+ gFn9^;?dr|E@0CRo%Td^=ld>F$rc63nwy7{c*)DG6w(%a4@z;=H~~n%9+_)Kr8{A zPm&5U003Z>u(E}ifuHuaMi4VGGZP0>Gh{(QWM>H2%*YPeJ>!eEd^AZTR{M43-sYvz z64N`}AyF0-IB3+Z7)@>+#Dv%~`c!q?zTtt+F)#5Ni?-T!LRb`0;FkSWj4jvUOzGRc z#D(=vzR!;}{l9(!?@hp0N70?>ogSMz*N+*d58PQk;l93^9lipsGhZC&oGwl+oUk=M z?yBvCl;?PqEAJ2uYkPmLa}G;Q-po+drt#vcYvK%9@gS0ygOFTTm$A&)$w9aUPWQpm1HW1ym_M$LU02eEl{$YNnD6wkF|W|pD4&tdG!T$my>s6^0bbu74D8z$2sO(q6!`OlEC56+HXD+U&Bn@SuM#e34cQyM7l{pH=4wY6vD zfWg+UxF{+IjyWymmx%gfMqaMiBAhn8Z?Qqlq38w|nnyw!f>Ssrt6ngzXI1rLqjN}P zE-G~14)uHglKbTCEpNhKJp=EB#YoZW*|lZFLQ@VO4nICyWK}UwFO9yWdA!0$A#(ez zNuf=e=tlA2UiG>fvSK(b!ssOd@Hb0bJAo;A>e4C~@8!{+3bh=iR(KKYKg$ucAR!0*YnlP!R ztNu>v3jan$uf3$cYNOs$$WHVIHvPN1y@$p3h2RekILJa^uSmmOJ2aocZDa=t-7ahF z;6T06K77ZAsSq>D{uSnlw$-)kv|0h^CRAE%gf3_|5&fA14Ff;4CSu#9@dRyuq3Cy( zvRa8O5tZ1ie40`!Z%1k4<`Q^q!0=u}bXH%)4c!$2uRk=Y3y+I+&lKGjLjjj}G?4FO-L$3>Rca$MzSZm~SlqmC=% zh4xkwC?Yo9shVl_`+?J-KlH{Zs|HV0xA$dwT|ix#V=DlnQWcRF8AB zmxhd&6Db>3|LoV^xZt+EYdt-RO^F=;I^1YZv}m{_z;h!!_Qv?HKnurhk{`ceoY zANp53bVeY2C2P84NGQJ!0)qerCiTuy=J=tnV??@K0ej>0xVivYS$wcQEM+!NePwT13X56pr{gUoA4>i}C%f{+_uqd29V=yDNTHy7S|=W47v*z8 zF_=|qRV+*)?ha5cGFo*T5?B=JBtV<%>WKWbxM>^JQQZz(e&#vna^1IR?}`zK&t$PG zboX|b92wsDh|o}kYA3}Fi|=*TUJ-pLSum07Ej?h zsbUsC*bErmc=?21P%@}c5RR+!=3>F10D22gBKIuXeCn|oxIWmdF49QSYFvu1{6Ayz zkV~Q&UbGz=UE_?~p=?;Z<$)4I!ox6#j0M5S`ZHoSFTXI+aLl~opX0}9|8z5B zN}JE@K`C;`CA8L!%}-^`mAnp98LXV?xoQ5xJa%@C`l$c-2QC?UR-vmnmzYQ`^9wuS zZQ!*nY(aVXR?$RPUJ`Po6K-_I7u1IKfb5e>2iSsN#&(T#BDk_?uB24nE>T#1m>a03S4&l=_uzCAvb2i^3jljR3%r-Uo!ejBb9!EtKY-Y^iA-LSm~d*Nk5=OG<>zOqs|f%A3Dh%APH^9V7x><4=FQAYCU0@l zV(l(`5lWqKgZ)&k=-`T#JDVg_CV0Ew}b zw4V`+GFU76gCV=ltQRNh4G#ZKEIHeWMUpnIQaMf9o*|4L6tMwO^@X6w5=mpLQ>Ob= zobSpjkry+SAJ{WsbNrp4u5u3Pw}V{X-5l23r;&=s`;fxC-bS#sNq4CcG&KUyB%HZK zhkR4aiAx15ZhtM$%0@7QRPO}=Cq5Cdjd2>1HPC-#uNDlxL=d=G)D~bR7ESGT>P{Mt zoYQ1zEYHDc)?Gi&O*cMWQ2RQ5q7QX^{5jd;0HT4i&zn(?xtQ8>MwVSo^7VurP4I-J za&W!i0Vg7CFgeX9QF|%Z4tF2Ssyex0h8MYwAL>pI3@rx0AV<7+Y#xz1m4L2pOw@Why z^=J_2=ULz)J;RVAI$aOORL_-|bfK6bI#cd8e;smdZh9a)BhWsoPg}|u^dV6;Hab#_ zjd+w<$B2gH743G1^5G1RskKrhbwzEI{Bl2J zuIzF|=F&A$$r{8Mk@)X?+l-;ul<7nf>Y)hXrFAwU{GV=kIAu{9h#4sJXlW`3{2yuv zoNji+KdHB{1;UeryJ)5>PE@RM3~%F>(Fmufko@&#u~H3QQsT*zx8w{v=0qA?ogEK( z>Im1Hk+Q^_HdeX|6wEcg0=mcQ>CGt{SML&Cc($BPue?CH9P%nI2k>f-n1oMtwUpm% zE+QTx=WMNeFydJ%=CGmc5o!-rw5B@=q8;Ued!PN+Nlhf7|vo@^FLiBsIz8D zI7D#jNeNQLNlXdKB7i2`;m{@gMn(*W)(if*bjK|km?IL@WD-!hB?H(Ohuhx+6leWy@qX1Iby~WeBg{stVvZ~&E zR|=+Y4^FrCpvmXI!}3^o#={_Gq!#E8Z*AXAYU7t$!gwumrdh*o)0lE&c097GRHVW#ZEy987VH| zJuq%5UYV(r$i-o=mYRAm_|7^fT3lFE2qVhUIq7j9nrS{(PCe@a#?C_j=%vm=O(822 z3c7)1h!zv{PFANubPo}9s_{U~;u!m2GLxz>Qf)57JVTZT&01#Pt0B>M$Y7eBF1a#q z1

K@EYSTu{42J<&;O!YX+Nyq#%y_ks9orui`w6H7a?>$u<;JS|;kfo~!eMU5g}| zncGKnwJ)joH(vym`&`fPd>8QRSJ46=b0mtx%!$c821IX_W*#ba<0A9X+WowAov%`@ zG&2(`qVte5Qa7~f{g|@(B3?pH8il)h2D;qPi}U223Z4?1qOaa#E3&=8+B*Cr7X`(; zAY)fkt(ebZgj2RXs_baj_$U46u#AzxZ_SGdB+=wz3tH_lk57s37DFc+!8nQfcJR0# zV5aVmyW-#z!l5;NZGXucn?}#bh~{TvXxz$1 zC=4U^4My^}4JcI7Jk;?~t8p~w4BkKL2*-|G*g*+@KLaGUJxb8ss z(jRv=1XP%yR`<3~LF3ly+~iZWaM(w({Fz*nLD;W-`N<@r*f}AD z3&pAuriE$5v6;x1kK(RZypmVizGMrsJ|qi`57$%V1&`3)4F}A~X7;9kMPE-x&vDyx zu=c!v3fI^+;P<#PFLzn zOG-<~)n~rVWS%HYe)HgzCc@?m#LR*h%lbefOl?gqf-OSVwv~uLFuk6MNkssIBQ5qK zde6`?0S7)c^o(4iPDPH=m`d0ul)Y9e4q~onZNMGC(r`L2vy6(r@w~b0V7dGu)xJ=S z)ykwiB%Q#t@J(7ONl8KMQPysCYO+SMQpoZ#YMcC~Do-yh!eV{F9X%`G#d6H1bNV&j z23{L3YY($_^K70#MPh&{!SS4aA(19sO|JWlbq%F!P^9>KIr{YJxm8mC>f6<^v9W%u zn`{qcn+SQ%JI7z=7+&0Vi>2)ya3A|noR9=tO@-A;r?O_zEnSpNFTds72YC|QaOL36 zC{&sjCtAV}u;J47M$w}O#<#4E4*+Kwin?T8eMY~_nGHYI#ygC7_|fOzkrh&ZLo~}k z8ZsSL1@}sL$QDF*Qyl1LTAkv+tpg*x7<8zFRLikdqAJrSdmJv>M+jX@lhV~J9)yjR zT#=%JJsB%yCgt-@z=EYKSEjv2fIMoR^16q#{H^P*#aN2K81+I&$cVqH(PX#^ZNGR4 zZ%k??FQc=^POw}7;@Do*Sn9-jit}j%`fi5TZhJ6W4D!^MH-vL762l%;C^PD!fe_~w znr~?*iWqqaQ}Gh#zq;u!=tt;nI+%r=5`5JR)qp})rso_+X&xK}WN*;;Esouy7khY@JU>kUMzft{C za`o;K-u95OQ@2w@@cg{Ppv>z!@%mlyz1D5V)uP&)$(%xBuD2D#!~*MEwRD0+X~mh# zbKf29#J{ApRae$$3gCk-A;%w*6z3)=aHGStLzn287QN3B@aw~Jzc49zr_0wl+0mk> zE%Rjc;F@a=b42BSKh}kz@x!22{Rh5F=HGqT>8XuAW!F3q-u>(og$$Q@MC zT_>%U2cgaUg2^^Yp(7<1zq7tDtvAZF!{Um^x5V)_#fD{oMh&Donv^n|UElPZ969?!7c}TL7GvX;6fi&v?zg%G%`X{6-{rHTl9~> zgesG}7)J>}0=B|PhxkG3I8O7pTMqemU*Y#KDDEi3MC=9JlZR23h7?$u>_DU$enB>& z8m4p2JgRO((wkuts5|Dl7DI)`H_9e{Zl*|DJ4|jG%U0`kzv^xwx^DwnyEosMCR6Ux ze_kf>a8HZ+{8*e*8Gpu*nRXkVKSsh+qa7Bi~Tesvl0 zcJ+fq@pNHh`%UgnP^sh6&iu>N+Mo4%=h|xv6J4a4U7`xOVye^A04hc}nv;zU*9Av3 zEq`mfGLJ!|9rGALtFurhsWJy{5P0sFl%AMcIz*r4bXKNgpbO;_-| z(klh}96|a}74VC85%u{uQWsA~l3Q&)66PHDhr7*kb8vD_pGL4*UmnE8{aSRW(4APx zGynrXymf28Ep+aDJmBt>XmRlWB_Z6|VlW8TIc{l0eEg%u;5!3++uPBmb^J>YvrgRa zOczQ`^gnLq74UzCEU<DrJN+tk!#7SLn6sxV|g&2@t?%HIBYI-V!osay;7zLh;oP zb>={mtXs8Ca~)n=oB)RL8MV9TmAEzp>j-}$LJm9@&@kxnF#DTm5a=ksdT6 zVP6<`&5Q0G-ej1(IocnN3P9j!DR1)ZDY&9%1{yW{`9S3DQ?|Z z^oA5h7$Z+s%XXo6R6ic(8FyiGpg$JBQt7XWD520JgnWr%Sr&JRwQot^H*Za}f7{r~ zSTQoSU61GA)(rXqi%J@-_F{7jojevNLbxX(!)8#fren3mYNR9)Efb**)?|q>%6r@k zPnD0G#YZg7;O=~R)i>PY!Abs0jil?^=3Ai;6&ZrR+;3T<6iQb#Yu0KnFInve1@O0A zDaq4{X;m^*0-6NmKG9TQ-RmOQ@((cTZ}mVexPQwAPoSXkq_}Dc~*kC35 zZ4*1FX{@72=TQV9yM-KMPLS{Xz=lpg?8`M1i%AFIHX?;txLz59@=7?(#E~R&88b3s zKTH>634l5=q7~edyPtSu)45z>!i8b_07K9npiwGqpIWb(IV~ahRhueTqa#Jah#QB} zh<_L_NhFDhRjOR|kbv})<3ZS-ZOH)ff}dP%lw2hit9z^)K;OOi>)iUM3fMUtsJYTk zbN>F%>C@JZN)gUmrUy9ur&SHy7QTJ47U&zlE^MIcAw|9_m+p{-){n}+UcUXv^K1Uh zVEv+BJ7RWcGBcYGMMlD{*eGHP&N$!ALpUa{H_fkVMC+6pou9Iy;AhX^mdox7yLC9E zpt6NmbH&Py{KbU>fSf^>cizP;^;9qC5gECwjV#CKUo5bhgMW4b=v*HHnO#b<3s}ei zeaVdID*)LqZq+4oIuJEg6zrd-&b5_P9o>8Pta+*a#wy^@7u&^N30T= z`_8Pju$X^?w)*bki`(HW2AKAq*|7j&uJ~fvj(^u$RhOTxL7H;Fxq&NdmqcN>A|jqz znZyD|k7e66c@S%OSTseory|r@#j)GRM23r1g0>H9i z7xoL<&wa8`A5=753>}m8D6AH?>OF*q3C~fgo6D44H`*wC5ByzMC;f>-na#;Q$ea{t zHSsiy;p_GZ9==m}uXqLcTH8LnREvMOZ=y{am`aFZVrDnk{~@F2h@<(aP1G@k6fb1P zKHm1jo^$^z`#H>ap}^h@b6NZl8A5Qdz!agzq@Gp>Rua}MR)hDex!B5xYxx}%Ciq#H zrSd3|cUizO??#@tSvDUJ>32t%%x1+KRGeZj{qFpH{69@=?J!)apQ(!^HcOPml^)rU za1Of;xbAA0m6lG-F>%ROey>EXOykpY2JFZTfb5eIL@Nm<<-q8i_tuZXU4$o#0gvu# zhyATry2myX3S6Yq0xF$^LR!i>nXoP48_C=@u*lYLElIid{s3ZO8DcRvPmer5&?E$-#Dp!cDfVW3sjbEs5#pk=1Yt{6M!fn2~#odOig^Cb^)KYu~^1c zS{%u_u*MO^{u#(2TD=LBXiSd@GPzU=t)onTISEMW55@nIK%ePBE=OW+V=8nmBRLx5%9Z zyd3ZnktbShQuzidRb78TEWB^p+!~|y4@I_A>o-epP9_bqH*G&&;ajIxhWF+Kg{_@SP|^mC{$4)ctTb&i zji8y{qvO>Yb())!)fou-!SIBSVejP%k2!Ebna!?pu1xfqgH>H@`iD|x1^W}1fYdOZ zUYiqHjGVCNjqg`W13b*LmuBPybEZ!<%~)NMZ~89Pi#W;?nDD~cVZ*#dz%71pTC*eC z889`fsnzs2sEuzQhFk}9+C76281cdbGs~gI}CX@IbCR!HGzmpUNuWs_iyxuD&T3pFmsRrzc zQDxWS(@_G=JJ2`J*~$=A3qM{2e??+Pger?OsH;!Vf`(}GQnu9=5Y0<8qS6OR4p*zA z{>7ynRO)gk^P@hhGY+}hgm5Cth-?XD)uQ6}NL^88WQRLwNcs`h&PCT8A4u9?pH9xG zPe(O6zIT1V?ht9aU$wNwEo7naxfcn#3aY5=$*#2C!P%1RR8vJtSK4Ma%U3rpu1%<4 ztGD#e=ZOKRico@Uy>ppD)u2R@PS^#$wdB5VeLy9(!aoZ3THD>j*2AYZM1)x*+8Q|| zR=E*{4T8SUPf4(%#IlOT(wd+K=Sp6F5CrMVVd5}TVpo&@a>5(#E@P?19+0raC%;@q9Xr;U6!*`h;ZJwOZJm=f_d#sS?sqf3O2)u6p11P~7^;oc#Qp>gMgl_q&Xp zy<)BIXC&0Bhb11<$x@guwe$lUzy2d=GJW0fqs%2~2HH0If8?RYe(b6v=Si?W}3cYf_OE zf}?lxfUM3RPKnbUTiJ+-u6RPKx?IqkeY z>@3Ll(bB(j_!dcX;k1&-=-zu57w}U1zS8$%=5*|S|69epAe-T-a+M^yKoKowZU}|g zbuB8bjXVl42PjaQ@<9zAGK385iA%$Y);jeySayxP@T02y>*ZyG?gXrd z7sjmHG-R~{8}$si1Dj_I9TfDWA;myx&O+^AP0WWS!y2gioVygJ#!VreD`Y zl{Zw>mt((`F5Oc9%3v00Q*Tl4R_}nbXkEfRr_-R_xX2Zro|&E?pC!K{-{Ivl3|Y=_ zOff0+M+{55Fr!+T%FY~Uw3|!#vih!3zSOp>c+n<$gY}@OoCkp|x?Y>ilx(vu{<|B# zAhZ>WDe3EL4%9n|grNRn=~qs9YU{E%eUlM=VSNj)obdQ?YJj3xhpL)fxEvOTUB|?x zqo2v61&?QH!_&(?EJH{yeho>#YVMKT+826>D%P+y(=b40Gi7tr;8MTTI9IWx;jtHZ zwYmK``Eb$e-1@P_`|4=B&{6g_mg-oMlviZd*2-$3rapgL@4C1l`s>HXqT__6>+Ymv zg^%y7MoY3Ji2D$Eoz1+@83SHD&LXej-S(_xQU-rc=6to+brflYd z)z8;c={4QMmKW>g8mzrERP;-+q3pzYquavGhs*583i<4QVHW}$qxVG|@L*Iu1voDx~NP?C_lBORz_sd9@3xtc7ozQjW;N5y`68uLWx%7bZWv0f+zF=eNA!PbJ|aX zXicgMrWZ+>9wkLF0*v&{BBE?Hob4l2eElc-K@Sp7AZ|1vZ@W7p4SnbzRx_}DOcSMH+jJM#S6>5i_E@* zk@a}oQU(7O5>JGc zoVm|(04_5MWYaPKK=|rm|16Ibm`CjCkvCC#h^gFA;?O8sov_2sB0(KuKQ)M8ZQ{`C zkq;2wXgRtTWnWXg(flsA56$33PT@wU^o>8TZ5#1du;}r`<5)!X2P@oYHs$c@MeIdb z{!mNy2k1ozUA^%Afn8+VHo(EP5a^H;!>!)J7u04%LuY^iqLUpwq?l9}d)J;t!X0m7 zety8Hurq_`(I@?yYHO7PtVYFZv-Ot9LzznQbpcSaA|~q97JMX$zT&&$*j>A=W>LHy zUOHY=iZO}7&ZHO%>>+iR*PkBRjm`j{oZb1|sONx=6O3Jp7Gl}Ln z1N<=UAk;r?V|v>_X?_c;@qul(vtg($54Ba@F9%kZ7n_clr!K@vH9+(aRVpXa&1$fM zS2RW5CCQVy#7d?O^_bV>=ZJa5?cBBn$wH8H$mh!(@Epk$fK zm^kbjB?=Udb3cg>yBH=PIN5v&nhmJLiJBc=G`k5jOHu{Fn`TLz%nmM_Y_UX+vQW6a z9_kbuEaMY1pY;TXNSP}r(Zwx|ntCFoCiQ}*(uSrHQ$?uFKLp4%#Ou8BZ=MlFrsYj} zkIZJ3+bai-3Bws36$_J7@-Tm&TEwsX!zDvL+JSa+FK4uAX=om#V19(kW?t77;S>jA z5YPsEIzJvQqKp$#F+V>x&67JXGYx#Po*h3lGx(hZfqhvTObHXe35eBkr=^UMk`rqD zPR+*{HN*nY>cORDj52kZ=)s0JPZ0^W%MLlG867vr{t97RC#1#+9E8NlC4TiQwn&J% zq*#!6EiDH04 zNPyGpFx7lgbmNfgqk&M=5k?!a_wgLDp+Mj2)cLhkc`A*)E+WY?2D^z+LLze)Q+TX3 zQX_$NSwxFC!mrYDY(}YwJr#7Q-UDKYXqgkDNDvt$xG`EOswSLdTidSG$1$?K$1Avp zedv{=PG}rPGo$~VDu151{cEb6o0I*YsdC=`o#F*OPfP#5A<%kpWA*_+)c#YC4Aks-9rR73-3ifJ&q}Ci`zGP>TXDKtY}$3Q$GR%i{RvXTW@%W*ZHRZY2s{) zdwA9+X4K7zO{1i>Z&m@hppzz~xIyQ*SrN?o?Q$_w6TjJG<_03jl&o%jTl%w2uB`N~ zEtrPeGpq2c(@}S^@qwrPX_>7jACK+vMxufZ$%&ol!34at@rF5Wi$T6u!ZzDW2 zUt2I8`>qs@&? zZ;0u=U|2_WiVY10tx@=?!}cvsN+8;${A|ca+%0frBmErHt@{$WTNJ2VL4UE0y@pwI zZAo8olUncBH41J@-`xN9_U}IYn~dQ7r?)}R7{cE(`hPfrXVO8;%-IBNxEG(bEil=OTOR5kvI}HHOpX;9uslScqCO**$W^90GpyA(u z!!yzGzwY>xor^K#H_f5;git)Y>bELl(iB{{P%<;`UG97+Bd`09fBz*$dk{Tm3EkF8ZyNDa6wGi5g)CJvZau3p+0m@Vx$e z;bP|k{J*S!uHj+-?>zsV*MA>#@jQb!|6cx-{%6bi^e+)0fRp#Be2%9**AqDNyytyx z==0H2T|v)}9j>QA@O;F}!SkeZJlSl2o#)_wep+z=LBDm*CphLgANHs9FHGhMcX3d+ zw|f301Mr;Hzx(umhxjv(^OT%Pfqe^D=>e-k}QU z=uM2QzzzTq3y6ya2%xuwKpdT4vp%6iPpSpj$kEcu#F@nbZ1L*PO(6R8kqE@f!Cvec zJE9kR%?4!S0CEFCpeIfV#HIb}-{t*ZT4yqH}{h>o__%Qzce5_7uQoy|Ch$e`Q-2a)PO+HQy>4Q z#{Qi9e`+8$cHmR>|0^92&(n?d9~ub633}?q|JI&I-G6FqK;VDZ1;hqq|L=By*m$1C z_kUmeJd*y=AYh}XdmjAz+a6UbPqSxVJwM%492}nd`w!~$Y)jjlKXvQxA@PS+G6F+> Sdl$t1 Date: Fri, 12 Jul 2024 10:25:13 +0300 Subject: [PATCH 008/171] Removed from tracking and updated tests --- examples/dev/LLaMAMLP.py | 35 +- examples/dev/conv2d_relu.py | 30 ++ examples/dev/log.out | 795 ------------------------------------ examples/dev/sdpa.py | 28 ++ examples/dev/sdpa_slow.py | 82 ++++ examples/dev/simple.py | 39 +- examples/dev/test_del.py | 33 ++ examples/dev/transformer.py | 15 + 8 files changed, 240 insertions(+), 817 deletions(-) create mode 100644 examples/dev/conv2d_relu.py delete mode 100644 examples/dev/log.out create mode 100644 examples/dev/sdpa.py create mode 100644 examples/dev/sdpa_slow.py create mode 100644 examples/dev/test_del.py create mode 100644 examples/dev/transformer.py diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 9a3418a447..15493f86ed 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -1,5 +1,6 @@ import torch import thunder +import time class LLaMAMLP(torch.nn.Module): def __init__(self, n_embd, intermediate_size) -> None: @@ -14,15 +15,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) with torch.device('cuda'): - model = LLaMAMLP(4096, 11008) - x = torch.randn(2, 2048, 4096, requires_grad=True) + a = 4096 * 3 + b = 11008 * 3 + model = LLaMAMLP(a, b) + x = torch.randn(2, 2048, a, requires_grad=True) jmodel = thunder.jit(model) - ans = jmodel(x) - print('---------------------------------------------- all traces') - for t in thunder.last_traces(jmodel): - print(t) - print('##############################################') - print('---------------------------------------------- ans') - print(ans) + tot_time = 0 + iters = 12 + for i in range(iters): + start = time.time_ns() + ans = jmodel(x) + torch.cuda.synchronize() + end = time.time_ns() + + # Skip the model without cache + if i > 1: + tot_time += (end - start) + print(f'tot time = {(end - start) / 1000000} ms') + + + # for t in thunder.last_traces(jmodel): + # print(t) + print(thunder.last_traces(jmodel)[-1]) + print(thunder.last_backward_traces(jmodel)[-1]) + print(f'Mean time = {(tot_time/(iters-2))/1000000} ms') + + print('deviation:', (jmodel(x) - model(x)).abs().max().item()) diff --git a/examples/dev/conv2d_relu.py b/examples/dev/conv2d_relu.py new file mode 100644 index 0000000000..2089184082 --- /dev/null +++ b/examples/dev/conv2d_relu.py @@ -0,0 +1,30 @@ +import torch +import thunder + +class Module(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1) -> None: + super().__init__() + self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride) + self.relu = torch.nn.ReLU() + + def forward(self, x: torch.Tensor): + a = self.conv2d(x) + b = self.conv2d(x) + c = self.conv2d(x + x) + d = self.relu(b * a) + return c + d + +with torch.device('cuda'): + model = Module(16, 33, 3, stride=2) + x = torch.randn(20, 16, 50, 100) + + jmodel = thunder.jit(model) + + ans = jmodel(x) + # print('---------------------------------------------- all traces') + # for t in thunder.last_traces(jmodel): + # print(t) + # print('##############################################') + # print('---------------------------------------------- ans') + # print(ans) + diff --git a/examples/dev/log.out b/examples/dev/log.out deleted file mode 100644 index 1cfe6efed1..0000000000 --- a/examples/dev/log.out +++ /dev/null @@ -1,795 +0,0 @@ -============================================ START: LABEL default -============================================ START: computation_trc split_forward_backward -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -============================================ END: computation_trc split_forward_backward -============================================ START: primal_trace sort_data_parallel_syncs -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -============================================ END: primal_trace sort_data_parallel_syncs -============================================ START: augmented_forward_pass -result -t8 -env -{'a': VJPDual(primal=t6, - residuals=((t0, False), None, ([], [t3], [], [t5], [t5, t0]))), - 'result': VJPDual(primal=t7, residuals=((t6, t1), None, ([t1, t6],))), - 't18': VJPDual(primal=t8, - residuals=((t7, t_proj_weight, None), - None, - ([t_proj_weight, t7],))), - 't_fc_1_weight': VJPDual(primal=t_fc_1_weight, residuals=()), - 't_fc_2_weight': VJPDual(primal=t_fc_2_weight, residuals=()), - 't_proj_weight': VJPDual(primal=t_proj_weight, residuals=()), - 'x': VJPDual(primal=x, residuals=()), - 'x_fc_1': VJPDual(primal=t0, - residuals=((x, t_fc_1_weight, None), - None, - ([t_fc_1_weight, x],))), - 'x_fc_2': VJPDual(primal=t1, - residuals=((x, t_fc_2_weight, None), - None, - ([t_fc_2_weight, x],)))} -============================================ END: augmented_forward_pass -============================================ START: primal_trace forward_and_backward_from_trace -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ END: primal_trace forward_and_backward_from_trace -============================================ START: before forward_trc transform_for_execution -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ END: after forward_trc transform_for_execution -============================================ START: LABEL forward_trc -============================================ START: before _transform_for_operator_executor_execution -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ END: before _transform_for_operator_executor_execution -============================================ START: after _transform_for_operator_executor_execution -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ GRAPH: _transform_for_operator_executor_execution -graph roots: 0, 1, 2, 3, -traversal nodes: -node ID 0 : [# x: "cuda:0 f32[2, 2048, 4096]"] - parents ids: - children ids: 13, 4, 5, -node ID 1 : [# t_fc_1_weight: "cuda:0 f32[11008, 4096]"] - parents ids: - children ids: 4, 13, -node ID 2 : [# t_fc_2_weight: "cuda:0 f32[11008, 4096]"] - parents ids: - children ids: 13, 5, -node ID 3 : [# t_proj_weight: "cuda:0 f32[4096, 11008]"] - parents ids: - children ids: 12, 13, -node ID 13 : [return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ())] - parents ids: 0, 1, 2, 3, 4, 5, 7, 9, 10, 11, 12, - children ids: -node ID 4 : [t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 0, 1, - children ids: 10, 13, 6, -node ID 5 : [t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 0, 2, - children ids: 11, 13, -node ID 12 : [t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 3, 11, - children ids: 13, -node ID 10 : [t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 9, 4, - children ids: 11, 13, -node ID 6 : [t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 4, - children ids: 7, -node ID 11 : [t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 10, 5, - children ids: 12, 13, -node ID 7 : [t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 6, - children ids: 8, 13, -node ID 8 : [t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 7, - children ids: 9, -node ID 9 : [t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 8, - children ids: 10, 13, - -============================================ END: after _transform_for_operator_executor_execution -============================================ START: after fusion_pass -# Constructed by Fusion (took 1 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t3, t5, t6, t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ END: after fusion_pass -============================================ START: after _transform_for_operator_executor_execution (always) -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t3, t5, t6, t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ END: after _transform_for_operator_executor_execution (always) -============================================ START: LABEL backward_trc ----------------------------------------------- all traces -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Transform for execution (took 18 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t3, t5, t6, t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Update Call Context (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## -# Constructed by Delete Last Used (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Transform for execution (took 18 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t3, t5, t6, t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Update Call Context (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## -# Constructed by Delete Last Used (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## ----------------------------------------------- ans -tensor([[[-0.0110, 0.0542, 0.0908, ..., -0.2110, 0.2082, 0.0331], - [ 0.4479, 0.0610, -0.0296, ..., 0.0564, -0.0904, -0.0877], - [ 0.1233, 0.0146, -0.1153, ..., 0.1049, 0.0266, -0.0702], - ..., - [-0.0202, 0.0180, -0.0293, ..., -0.0630, 0.1042, 0.0283], - [-0.0221, -0.0508, 0.1574, ..., 0.1687, -0.0135, 0.0040], - [ 0.0555, 0.0216, 0.2707, ..., -0.0414, 0.1786, -0.2664]], - - [[-0.0476, -0.1409, -0.0704, ..., 0.0162, 0.0102, -0.0570], - [-0.1471, 0.0132, -0.2057, ..., 0.0787, -0.0048, -0.0167], - [-0.0957, -0.1662, -0.0485, ..., 0.0173, 0.0265, -0.0916], - ..., - [-0.0264, -0.0388, -0.2041, ..., 0.0679, 0.0027, -0.0122], - [ 0.0222, 0.0553, -0.2055, ..., 0.0905, 0.1831, 0.0558], - [ 0.0816, -0.0930, 0.0024, ..., -0.1418, -0.0122, -0.0344]]], - device='cuda:0', grad_fn=) diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py new file mode 100644 index 0000000000..a4d1d35328 --- /dev/null +++ b/examples/dev/sdpa.py @@ -0,0 +1,28 @@ +import torch +import thunder + +class Module(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, query, key, value): + query = query + query + key = key * key + a = torch.nn.functional.scaled_dot_product_attention(query, key, value) + return a + + +with torch.device('cuda'): + module = Module() + j_module = thunder.jit(module) + + query = torch.rand(32, 8, 128, 64, dtype=torch.float16) + key = torch.rand(32, 8, 128, 64, dtype=torch.float16) + value = torch.rand(32, 8, 128, 64, dtype=torch.float16) + + ans = j_module(query, key, value) + + print(thunder.last_traces(j_module)[-1]) + + + diff --git a/examples/dev/sdpa_slow.py b/examples/dev/sdpa_slow.py new file mode 100644 index 0000000000..c9d4381757 --- /dev/null +++ b/examples/dev/sdpa_slow.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import thunder +import math + +class ModelConfig: + def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): + self.n_embd = n_embd + self.n_head = n_head + self.dropout = dropout + self.bias = bias + self.block_size = block_size + +class Module(nn.Module): + def __init__(self, config): + """ + My implementation of NanoGPT Causal Self Attention module for PyTorch. + + Args: + - config: Configuration object containing parameters for the attention module. + """ + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) + + def forward(self, x): + """ + Forward pass of the Causal Self Attention module. + + Args: + - x: Input tensor. + + Returns: + - torch.Tensor: Output tensor after self-attention. + """ + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + +with torch.device('cuda'): + config = ModelConfig(n_embd = 1536) + module = Module(config) + j_module = thunder.jit(module) + + batch_size, sequence_length, embedding_dim = 8, 16, config.n_embd + x = torch.randn((batch_size, sequence_length, embedding_dim)) + + ans = j_module(x) + + print(thunder.last_traces(j_module)[-1]) + + + diff --git a/examples/dev/simple.py b/examples/dev/simple.py index a2ceffcf7e..631caa8a89 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -1,26 +1,39 @@ import torch import thunder +import time class Module(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - + def __init__(self, in_features, out_features) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + self.silu = torch.nn.SiLU() def forward(self, x: torch.Tensor): a = x + x - return a + # a_silu = self.silu(a) + b: torch.Tensor = self.linear(a) + c = b * b + # c_silu = self.silu(c) + d = c + c + return d with torch.device('cuda'): - model = Module() - x = torch.randn(2, 2) + multiplier = 10 + in_features = 20 * multiplier + out_features = 30 * multiplier + model = Module(in_features, out_features) + x = torch.randn(128, in_features) jmodel = thunder.jit(model) - ans = jmodel(x) - print('---------------------------------------------- all traces') - for t in thunder.last_traces(jmodel): - print(t) - print('##############################################') - print('---------------------------------------------- ans') - print(ans) + for _ in range(100): + start = time.time_ns() + ans = jmodel(x) + end = time.time_ns() + # print('---------------------------------------------- all traces') + # for t in thunder.last_traces(jmodel): + # print(t) + # print('##############################################') + + print(f'tot time = {(end - start) / 1000000} ms') diff --git a/examples/dev/test_del.py b/examples/dev/test_del.py new file mode 100644 index 0000000000..71aab23cb0 --- /dev/null +++ b/examples/dev/test_del.py @@ -0,0 +1,33 @@ +import torch +import time + +iters = 1000 + +with torch.device('cuda'): + + tot_time = 0 + for i in range(iters): + s = time.time_ns() + a = torch.randn(2, 2048, 4096 // 1, requires_grad=True) + b = torch.randn(2, 2048, 4096 // 1, requires_grad=True) + c = a + b + a + b + c = c * c + del a + del b + del c + torch.cuda.synchronize() + tot_time += (time.time_ns() - s) + + print(f"With del = {(tot_time / iters) / 1000000}") + + tot_time = 0 + for i in range(iters): + s = time.time_ns() + a = torch.randn(2, 2048, 4096 // 1, requires_grad=True) + b = torch.randn(2, 2048, 4096 // 1, requires_grad=True) + c = a + b + a + b + c = c * c + torch.cuda.synchronize() + tot_time += (time.time_ns() - s) + + print(f"With no del = {(tot_time / iters) / 1000000}") diff --git a/examples/dev/transformer.py b/examples/dev/transformer.py new file mode 100644 index 0000000000..2a4c5d4be6 --- /dev/null +++ b/examples/dev/transformer.py @@ -0,0 +1,15 @@ + +import torch +import thunder + +with torch.device('cuda'): + transformer_model = torch.nn.Transformer(nhead=16, num_encoder_layers=12) + src = torch.rand((10, 32, 512)) + tgt = torch.rand((20, 32, 512)) + out = transformer_model(src, tgt) + print(out) + + + jmodel = thunder.jit(transformer_model) + out = jmodel(src, tgt) + From f09f9b3ce6a24908a30bca62a9d23f09ec0a2efe Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 12 Jul 2024 10:32:24 +0300 Subject: [PATCH 009/171] Removed import --- thunder/backend_optimizer/optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index ba4213701d..a198ceda3a 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,7 +1,6 @@ from typing import Any, Hashable import torch import thunder -from thunder.clang import sub from thunder.core.baseutils import BoundSymbolInterface from thunder.core.utils import check, safe_map_flat from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable From 17df099a04861d4a8ad13c628819cd5d5ec82dcc Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 12 Jul 2024 12:44:34 +0300 Subject: [PATCH 010/171] Change timing function --- thunder/executors/passes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 70d50f847e..b37298dd8f 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -142,7 +142,7 @@ def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[E # Recover the function name sig_name = cutils.get_siginfo_name(trace) - start_time_ns = time.time_ns() + start_time_ns = time.perf_counter_ns() if torch.distributed.is_available(): # Apply AllReduce bucketing if possible & needed @@ -157,7 +157,7 @@ def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[E backend_optimizer.benchmark_traces() extrace = backend_optimizer.get_optimal_trace() - end_time_ns = time.time_ns() + end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 From bdca10b5c5329edeaefb18bda70b49f573f99357 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 12 Jul 2024 12:45:23 +0300 Subject: [PATCH 011/171] Cleaned impl and updated debug files generation --- thunder/backend_optimizer/optimizer.py | 499 ++++--------------------- 1 file changed, 76 insertions(+), 423 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index a198ceda3a..9372885988 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,19 +1,19 @@ -from typing import Any, Hashable -import torch -import thunder +from collections.abc import Callable, Sequence +from enum import Enum +from itertools import chain from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.utils import check, safe_map_flat from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable from thunder.core.symbol import BoundSymbol, Symbol -from thunder.executors.data_dependent_partition import Graph, Node from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx +from thunder.core.utils import check, safe_map_flat +from thunder.executors.data_dependent_partition import Graph, Node from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors -import thunder.core.transforms as transforms from thunder.visualizer.visualizer_helper import Visualizer -from collections.abc import Callable, Sequence -from enum import Enum -from itertools import chain +from typing import Any, Hashable +import thunder +import thunder.core.transforms as transforms import time +import torch # import concurrent.futures class OptimizerNode(): @@ -28,7 +28,7 @@ class BackendOptimizer(): def log(self, what: str): print(f'================================================================================ Autotune: {what}') - def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=True, log_file_name='autotune_traces_computation_time.log', visualizer: Visualizer | None = None) -> None: + def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=True, log_file_name='autotune_debug.log', visualizer: Visualizer | None = None) -> None: self.trace: TraceCtx = trace self.incremental_search_out_trace: TraceCtx self.optimal_trace: TraceCtx = trace @@ -41,7 +41,7 @@ def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=T self.always_executors: tuple[Executor, ...] = get_always_executors() self.produce_log: bool = produce_log self.log_file_name: str = log_file_name - self.log_str: str = "" + self.debug_msg: str = "" self.visualizer: Visualizer | None = visualizer self.partial_costs: dict[TraceCtx, float] = {} @@ -61,19 +61,17 @@ def write(self, file_name): file.write(s) file.close() + class SearchNode: + def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: + self.symbol = symbol + self.idx = idx + # TODO (matteochen): this has a lot in common with the exaustive search, compact them def build_placement_options_incremental(self): import sys old_max_recursion = sys.getrecursionlimit() - # TODO (matteochen): parametrize this - sys.setrecursionlimit(20000) - - - class SearchNode: - def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: - self.symbol = symbol - self.idx = idx + sys.setrecursionlimit(2000) # Last index inclusive def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: list[Executor]) -> tuple[float, TraceCtx]: @@ -99,9 +97,6 @@ def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: li t.bound_symbols.append(forced_return_bsym) configuration.append(Executor(name=self.empty_executor_hashable_placeholder)) # Empty executor for the forced_return - # self.log(f'Debug\n{len(t.bound_symbols)}\n{len(exs)}') - # self.log(f'Debug\n{(t)}\n') - # Place the assigned symbols placed_t = self.place_optimizers(t, configuration) @@ -116,13 +111,13 @@ def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: li return cost, placed_t # We assign an internal id to each symbol based on its idx inside the bound_symbols list - def search(node: SearchNode, configuration: list[Executor]): + def search(node: self.SearchNode, configuration: list[Executor]): - def continue_search(time_inc: float): + def continue_search(): if node.idx+1 < max_len: new_idx: int = node.idx + 1 new_symbol: BoundSymbolInterface = bound_symbols[new_idx] - search(SearchNode(new_symbol, new_idx), configuration) + search(self.SearchNode(new_symbol, new_idx), configuration) else: all_configurations.append(configuration) @@ -132,46 +127,37 @@ def continue_search(time_inc: float): ex: Executor # TODO (matteochen): do parallel for for ex in self.executors: + cost = float('inf') if not isinstance(node.symbol, BoundSymbol): raise AssertionError("Receive a symbol which is not a BoundSymbol") if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): - # print(f'{node.idx}-{ex._name} can execute symbol {node.symbol.sym.name}') - # safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) has_backend = True configuration.append(ex) - cost, extrace = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + cost, _ = benchmark_partial_trace(self.trace, node.idx, list(configuration)) configuration.pop() - if cost < min_cost: - min_cost = cost - min_cost_ex = ex - if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): - # print(f'{node.idx}-{ex._name} can fuse symbol {node.symbol.sym.name}') - # safe_update_dict(node.idx, ExecutorType.FUSER, ex) has_backend = True configuration.append(ex) - cost, extrace = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + cost, _ = benchmark_partial_trace(self.trace, node.idx, list(configuration)) configuration.pop() - if cost < min_cost: - min_cost = cost - min_cost_ex = ex + if cost < min_cost: + min_cost = cost + min_cost_ex = ex if not has_backend: configuration.append(empty_executor) - continue_search(0.0) + continue_search() else: if min_cost_ex is None: raise AssertionError("Unexpected min cost executor or trace: None") self.log(f'For id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}\n') - # log_min_cost_trace(min_cost_trace) configuration.append(min_cost_ex) - continue_search(min_cost) + continue_search() - # res: dict[int, dict[ExecutorType, list[Executor]]] = {} bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols max_len = len(bound_symbols) @@ -180,202 +166,17 @@ def continue_search(time_inc: float): empty_executor = Executor(name=self.empty_executor_hashable_placeholder) if len(bound_symbols) > 0: - search(SearchNode(bound_symbols[0], 0), []) + search(self.SearchNode(bound_symbols[0], 0), []) self.placement_options = all_configurations sys.setrecursionlimit(old_max_recursion) - # TODO (matteochen): this has a lot in common with the exaustive search, compact them - # def build_placement_options_incremental(self): - # class SearchNode: - # def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: - # self.symbol = symbol - # self.idx = idx - - # def retrieve_executors_from_trace(trace_in: TraceCtx, last_symbol_idx:int = -1) -> list[Executor]: - # executors: list[Executor] = [] - # if last_symbol_idx == -1: - # last_symbol_idx = len(trace_in.bound_symbols) - # for i in range(last_symbol_idx): - # if not isinstance(trace_in.bound_symbols[i], BoundSymbol): - # raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') - # s = trace_in.bound_symbols[i] - # if s.sym.executor is None: - # executors.append(empty_executor) - # else: - # executors.append(s.sym.executor) - # return executors - - # # Last index inclusive - # def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, new_ex: Executor) -> tuple[float, TraceCtx, Any]: - - # exs: list[Executor] = retrieve_executors_from_trace(trace_in, last_idx) - # # for i in range(last_idx): - # # if not isinstance(trace_in.bound_symbols[i], BoundSymbol): - # # raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') - # # s = trace_in.bound_symbols[i] - # # if s.sym.executor is None: - # # exs.append(empty_executor) - # # else: - # # exs.append(s.sym.executor) - # exs.append(new_ex) - - # # Retrive all output tensors from each subregion - # tensors = [] - # for i in range(last_idx+1): - # if not isinstance(trace_in.bound_symbols[i], BoundSymbol): - # raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') - # s = trace_in.bound_symbols[i] - # # For each bsym region we expect to output a Tensor - # tensors.append(s.output) - # # print('Tensors inside partial trace') - # # for t in tensors: - # # print(t) - - # forced_return_bsym = trace_in.bound_symbols[-1].from_bsym(args=tensors) # Should not be an Interface type at this point - - # t = from_trace(trace_in) - # # Cut the trace to the required depth - # t.bound_symbols = list(trace_in.bound_symbols)[:last_idx+1] - - # t.bound_symbols.append(forced_return_bsym) - # exs.append(Executor(name=self.empty_executor_hashable_placeholder)) # Empty executor for the forced_return - - # # self.log(f'Debug\n{len(t.bound_symbols)}\n{len(exs)}') - # # self.log(f'Debug\n{(t)}\n') - - # # Place the assigned symbols - # placed_t = self.place_optimizers(t, exs) - - # cost, answer = benchmark_trace(placed_t) - # self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') - # self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') - # self.log(f'Assigned executor = {exs[-2].name}') - # self.log(f'Time = {cost/1000000} ms') - # self.partial_costs[t] = cost - # return cost, placed_t, answer - - # # We assign an internal id to each symbol based on its idx inside the bound_symbols list - # def search(node: SearchNode, time_so_far: float): - - # def continue_search(time_inc: float): - # if node.idx+1 < max_len: - # new_idx: int = node.idx + 1 - # new_symbol: BoundSymbolInterface = bound_symbols[new_idx] - # search(SearchNode(new_symbol, new_idx), time_so_far + time_inc) - # else: - # all_configurations.append(retrieve_executors_from_trace(self.incremental_search_out_trace)) - # self.log(f'Incremental search ended:\n{self.incremental_search_out_trace}\n{all_configurations[0]}') - - # # def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): - # # if idx not in res: - # # res[idx] = {} - # # res[node.idx][type] = [ex] - # # else: - # # if type not in res[idx]: - # # res[node.idx][type] = [ex] - # # else: - # # res[node.idx][type].append(ex) - - # def log_min_cost_trace(trace: TraceCtx): - # self.log(f'Min cost trace:\n{trace}') - # b: BoundSymbol - # for b in trace.bound_symbols: - # self.log(f'sym = {b.sym.name} , ex = {b.sym.executor}') - - # def extend_min_cost_trace(trace_in: TraceCtx, idx_from_to_extend: int): - # new_items = list(self.trace.bound_symbols[idx_from_to_extend:]) - # # Remove the mock return statement - # trace_in.bound_symbols.pop() - # trace_in.bound_symbols.extend(new_items) - - # def update_self_trace(trace_in: TraceCtx): - # self.incremental_search_out_trace = from_trace(trace_in) - # self.incremental_search_out_trace.bound_symbols = list(trace_in.bound_symbols) - - # has_backend = False - # min_cost = float('inf') - # min_cost_ex = None - # min_cost_trace = from_trace(self.incremental_search_out_trace) - # min_cost_trace.bound_symbols = list(self.incremental_search_out_trace.bound_symbols) - # # self.log(f'New iter, node idx = {node.idx}') - # log_min_cost_trace(min_cost_trace) - - # trace_iter = from_trace(self.incremental_search_out_trace) - # trace_iter.bound_symbols = list(self.incremental_search_out_trace.bound_symbols) - - # # Seach for last placed executor index in min_cost_trace - # idx = 0 - # while idx < len(min_cost_trace.bound_symbols) and not self.bsym_assigned(min_cost_trace.bound_symbols[idx]): - # idx += 1 - # while idx < len(min_cost_trace.bound_symbols) and self.bsym_assigned(min_cost_trace.bound_symbols[idx]): - # idx += 1 - # # With Fusion operators, our trace will be collapsed. If the min_cost_trace is assigned to a trace that comes out from a fusion pass - # # the length of the partial trace (local optimal) to be injected inside benchmark_partial_trace is < node.idx - # idx = min(idx, node.idx) - - # ex: Executor - # for ex in self.executors: - # if not isinstance(node.symbol, BoundSymbol): - # raise AssertionError("Receive a symbol which is not a BoundSymbol") - # if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): - # # print(f'{node.idx}-{ex._name} can execute symbol {node.symbol.sym.name}') - # # safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) - # has_backend = True - - # cost, extrace, tensor_out = benchmark_partial_trace(self.incremental_search_out_trace, idx, ex) - - # if cost < min_cost: - # min_cost = cost - # min_cost_ex = ex - # min_cost_trace = extrace - - # if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): - # # print(f'{node.idx}-{ex._name} can fuse symbol {node.symbol.sym.name}') - # # safe_update_dict(node.idx, ExecutorType.FUSER, ex) - # has_backend = True - - # cost, extrace, tensor_out = benchmark_partial_trace(self.incremental_search_out_trace, idx, ex) - - # if cost < min_cost: - # min_cost = cost - # min_cost_ex = ex - # min_cost_trace = extrace - - # if not has_backend: - # continue_search(0.0) - # # configuration.pop(-1) - # else: - # if min_cost_ex is None or min_cost_trace is None: - # raise AssertionError("Unexpected min cost executor or trace: None") - # self.log(f'For id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}') - # if node.idx + 1 < max_len: - # extend_min_cost_trace(min_cost_trace, node.idx+1) - # # log_min_cost_trace(min_cost_trace) - # update_self_trace(min_cost_trace) - # continue_search(min_cost) - - # # Assign search initial trace - # self.incremental_search_out_trace = from_trace(self.trace) - # self.incremental_search_out_trace.bound_symbols = list(self.trace.bound_symbols) - - # # res: dict[int, dict[ExecutorType, list[Executor]]] = {} - # bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols - # max_len = len(bound_symbols) - - # all_configurations: list[list[Executor]] = [] - # # Is the name reserved? - # empty_executor = Executor(name=self.empty_executor_hashable_placeholder) - - # if len(bound_symbols) > 0: - # search(SearchNode(bound_symbols[0], 0), 0.0) - # self.placement_options = all_configurations - # This expects a trace after the placement call. # Fusion operators as nvFuser can be slower on the single trace region but can be faster by combining more of them, # try to fuse then and compare def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: + # Fuser call is expecting a fresh trace with no prior fusion calls hence the fusion regions count will start from 0. We need to override that logic def count_fusion_regions(trace_in: TraceCtx) -> int: count = 0 for bsym in trace_in.bound_symbols: @@ -389,9 +190,6 @@ def count_fusion_regions(trace_in: TraceCtx) -> int: del answer trace_in_time = best_time - # for bsym in trace_in.bound_symbols: - # print(f'subsymbols: {bsym.subsymbols}') - fusion_regions = count_fusion_regions(trace_in) self.log(f'Try to fuse. Fusion regions already present: {fusion_regions}') @@ -413,51 +211,28 @@ def count_fusion_regions(trace_in: TraceCtx) -> int: return best_trace def build_placement_options_exaustive(self): - class ExecutorType(Enum): - OPERATOR = 1 - FUSER = 1 - - class SearchNode: - def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: - self.symbol = symbol - self.idx = idx # We assign an internal id to each symbol based on its idx inside the bound_symbols list - def search(node: SearchNode, configuration): + def search(node: self.SearchNode, configuration): def continue_search(): if node.idx+1 < max_len: new_idx: int = node.idx + 1 new_symbol: BoundSymbolInterface = bound_symbols[new_idx] - search(SearchNode(new_symbol, new_idx), configuration) + search(self.SearchNode(new_symbol, new_idx), configuration) else: - # print(f'reached end of search for this tree branch {configuration}') all_configurations.append(list(configuration)) - def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): - if idx not in res: - res[idx] = {} - res[node.idx][type] = [ex] - else: - if type not in res[idx]: - res[node.idx][type] = [ex] - else: - res[node.idx][type].append(ex) - ex: Executor has_backend = False for ex in self.executors: if not isinstance(node.symbol, BoundSymbol): raise AssertionError("Receive a symbol which is not a BoundSymbol") if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): - # print(f'{node.idx}-{ex._name} can execute symbol {node.symbol.sym.name}') - safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) has_backend = True configuration.append(ex) continue_search() configuration.pop(-1) if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): - # print(f'{node.idx}-{ex._name} can fuse symbol {node.symbol.sym.name}') - safe_update_dict(node.idx, ExecutorType.FUSER, ex) has_backend = True configuration.append(ex) continue_search() @@ -468,7 +243,6 @@ def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): continue_search() configuration.pop(-1) - res: dict[int, dict[ExecutorType, list[Executor]]] = {} bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols max_len = len(bound_symbols) @@ -477,106 +251,43 @@ def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): empty_executor = Executor(name=self.empty_executor_hashable_placeholder) if len(bound_symbols) > 0: - search(SearchNode(bound_symbols[0], 0), []) + search(self.SearchNode(bound_symbols[0], 0), []) self.placement_options = all_configurations - # def build_placement_options_parallel(self): - # class ExecutorType(Enum): - # OPERATOR = 1 - # FUSER = 1 - - # class SearchNode: - # def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: - # self.symbol = symbol - # self.idx = idx - - # # We assign an internal id to each symbol based on its idx inside the bound_symbols list - # def search(node: SearchNode, configuration, all_configurations, level = 0): - # def update(): - # # print(f'{node.idx + 1} >= {max_len}, reached end of search for this tree branch (len = {len(configuration)}) {configuration}') - # all_configurations.append(list(configuration)) - - # def safe_update_dict(idx: int, type: ExecutorType, ex: Executor): - # if idx not in res: - # res[idx] = {} - # res[node.idx][type] = [ex] - # else: - # if type not in res[idx]: - # res[node.idx][type] = [ex] - # else: - # res[node.idx][type].append(ex) - - # futures = [] - # with concurrent.futures.ThreadPoolExecutor(max_workers=100) as concurrent_executor: - - # has_backend = False - # new_idx: int = node.idx + 1 - - # if new_idx >= max_len: - # # As this is the last symbol, we expect a return statement by default - # configuration.append(empty_executor) - # update() - # return - - # new_symbol: BoundSymbolInterface = bound_symbols[new_idx] - # new_node = SearchNode(new_symbol, new_idx) - - # ex: Executor - # for ex in self.executors: - - # if not isinstance(node.symbol, BoundSymbol): - # raise AssertionError("Receive a symbol which is not a BoundSymbol") - # if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): - # safe_update_dict(node.idx, ExecutorType.OPERATOR, ex) - # has_backend = True - # configuration.append(ex) - # futures.append(concurrent_executor.submit(search, new_node, list(configuration), all_configurations, level+1)) - # configuration.pop(-1) - # if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): - # safe_update_dict(node.idx, ExecutorType.FUSER, ex) - # has_backend = True - # configuration.append(ex) - # futures.append(concurrent_executor.submit(search, new_node, list(configuration), all_configurations, level+1)) - # configuration.pop(-1) - - # if not has_backend: - # configuration.append(empty_executor) - # futures.append(concurrent_executor.submit(search, new_node, list(configuration), all_configurations, level+1)) - # configuration.pop(-1) - - # if level == 0: - # concurrent.futures.wait(futures) - - # res: dict[int, dict[ExecutorType, list[Executor]]] = {} - # bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols - # bound_symbols_name = [s.sym.name for s in bound_symbols] - # max_len = len(bound_symbols) - - # all: list[list[Executor]] = [] - # # Is the name reserved? - # empty_executor = Executor(name=self.empty_executor_hashable_placeholder) - - # print(f'input trace bound symbols name len {len(bound_symbols_name)}: {bound_symbols_name}') - - # import time - - # if len(bound_symbols) > 0: - # start = time.time_ns() - # search(SearchNode(bound_symbols[0], 0), [], all) - # end = time.time_ns() - # print(f'End of search, tot time = {(end - start)/1000000} ms. Configurations len = {len(all)}') - # self.placement_options = all - # # for config in all_configurations: - # # c_str = [str(c.name) for c in config] - # # c_str = " ".join(c_str) - # # print(c_str) - def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: from thunder.executors.passes import _transform_for_operator_executor_execution swapmap: dict[Variable, Proxy] = {} + # During the fusion pass and CSE optimizatons some args in trace regions could be different from the cached args. Restore the correct arguments + # https://pytorch-lightning.slack.com/archives/C06QA9M8L3C/p1720732254341999 + def restore_correct_args(trace_in: TraceCtx): + + def args_eq(a, b) -> bool: + if len(a) != len(b): + return False + for obj_a, obj_b in zip(a, b): + if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): + if obj_a.name != obj_b.name: + return False + elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): + if obj_a != obj_b: + raise AssertionError(f'What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}') + return True + + def clear(bsym: BoundSymbol, input): + size = len(bsym.subsymbols) + if size > 0: + for subsym in bsym.subsymbols: + if not args_eq(subsym.args, input): + subsym.args = tuple(list(input)) + clear(subsym, input) + + for bsym in trace_in.bound_symbols: + if isinstance(bsym.sym.executor, OperatorExecutor): + clear(bsym, bsym.args) + def update_swapmap(o: Any, no: Any) -> None: if isinstance(o, Proxy): check( @@ -603,17 +314,9 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: if bsym.sym.python_impl is not None: return None - # if self.bsym_assigned(bsym): - # return None - # if bsym.sym.executor is not None: - # return None - # We have mapped this at previous stages if ex.name == self.empty_executor_hashable_placeholder: return None - # The call above represent: - # if bsym.sym.executor is not None: - # return None execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) out: Any @@ -649,9 +352,6 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: extrace.bound_symbols = bound_symbols - # self.log(f'Place optimizer, before fusion pass trace:\n{extrace}') - - # proxy_names_to_ignore = set() unique_fusion_executors = set() cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} @@ -668,19 +368,11 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: # This will leave out these symbols from the fusion pass bsym.subsymbols = [] - # proxy_names_to_ignore.add(t_proxy.name) - - # self.log(f'To ignore:\n{proxy_names_to_ignore}') - - self.log(f'Before fusion pass trace\n{extrace}') - # Perform fusion pass # TODO (matteochen): filter for the current fusion operator as we wanna find the most efficient one for ex in unique_fusion_executors: extrace = ex.fusion_pass(extrace) - self.log(f'After fusion pass trace\n{extrace}') - # Restore subsymbols # TODO (matteochen): Improve this search for k, v in cached_subsymbols.items(): @@ -694,41 +386,14 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: if isinstance(bsym.output, TensorProxy) and bsym.output.name == k: bsym.subsymbols = v + restore_correct_args(extrace) + # Apply always executors extrace = _transform_for_operator_executor_execution(extrace, self.always_executors) - # self.log(f'Place optimizer, after always executors:\n{extrace}') - return extrace - def clear_bad_inputs(self, trace_in: TraceCtx): - - def args_eq(a, b) -> bool: - if len(a) != len(b): - return False - for obj_a, obj_b in zip(a, b): - if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): - if obj_a.name != obj_b.name: - return False - return True - - def clear(bsym: BoundSymbol, input): - size = len(bsym.subsymbols) - if size > 0: - for subsym in bsym.subsymbols: - if not args_eq(subsym.args, input): - print(f'Sub = {subsym.sym.name} Args = {subsym.args}') - print(f'Got subsymbol {subsym.sym.name} with different inputs from {bsym.sym.name}') - subsym.args = tuple(list(input)) - clear(subsym, input) - - # Solve the issue of nvfuser mismatrch input args - for bsym in trace_in.bound_symbols: - if bsym.sym.executor is not None: - print(f'Calling clear for {bsym.sym.name} Args({type(bsym.args)}) = {bsym.args}\n') - clear(bsym, bsym.args) - # TODO (matteochen): add config for exaustive search or incremental one def optimize(self, strat: OptimizationStrat = OptimizationStrat.GREEDY): import thunder.core.codeutils as cutils @@ -742,20 +407,7 @@ def greedy(): raise AssertionError("Unexpected placement options size") option = self.placement_options[0] - # self.log(f'sym len: {len(self.trace.bound_symbols)} options len = {len(option)}') - # self.log(f'Trace to optimize\n{self.trace}') - # self.log('Chosen options:') - # for s, o in zip(self.trace.bound_symbols, option): - # print(f'{s.sym.name} -> {o.name}') - # Place the assigned executors - self.log(f'Placing optimizers for greedy trace:\n{self.trace}') - for s, o in zip(self.trace.bound_symbols, option): - print(f'{s.sym.name} -> {o.name}') trace_greedy = self.place_optimizers(self.trace, option) - self.log(f'Greedy trace:\n{trace_greedy}') - - self.clear_bad_inputs(trace_greedy) - # Append the unique trace self.optimized_traces.append({'greedy': trace_greedy}) @@ -800,16 +452,16 @@ def exaustive(): def get_optimal_trace(self) -> TraceCtx: return self.optimal_trace - def bsym_assigned(self, bsym: BoundSymbol) -> bool: return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) - def benchmark_traces(self): min_run_time = float('inf') optimal_trace: TraceCtx = self.trace # Assign initial value for unbound errors best_label = "" + self.debug_msg += 'Traces benchmarks:\n\n' + for trace_info in self.optimized_traces: label = None @@ -820,23 +472,24 @@ def benchmark_traces(self): trace_time, res = benchmark_trace(trace) del res + self.debug_msg += f'Trace name = [{label}] - Time = {trace_time / 1000000} ms\n{trace}\n\n' self.log(f'Benchmark trace "{label}" (time = {trace_time / 1000000} ms):\n{trace}') if trace_time < min_run_time: min_run_time = trace_time optimal_trace = trace best_label = label - self.log(f'Benchmark end: Best trace "{best_label}":\n{optimal_trace}') + self.log(f'Benchmark end: Best trace "{best_label} (time = {min_run_time})":\n{optimal_trace}') self.optimal_trace = optimal_trace - with open(self.log_file_name, 'w') as file: - file.write(self.log_str) - file.close() - + if self.produce_log: + with open(self.log_file_name, 'w') as file: + file.write(self.debug_msg) + file.close() -# This will benpchmark the input trace with the del_last_used call -def benchmark_trace(trace: TraceCtx) -> tuple[float, Any]: +# This will benchmark the input trace with the del_last_used call +def benchmark_trace(trace: TraceCtx, iters: int = 1) -> tuple[float, Any]: from thunder.executors.passes import del_last_used input_args = [] @@ -845,10 +498,10 @@ def compute_time_cost(fn: Callable, iters: int, *args) -> tuple[float, Any]: total_time = 0 out = None for _ in range(iters): - time_s = time.time_ns() + time_s = time.perf_counter_ns() out = fn(*args) torch.cuda.synchronize() - time_e = time.time_ns() + time_e = time.perf_counter_ns() total_time += (time_e - time_s) return total_time / iters, out @@ -932,7 +585,7 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: # Obtain the python executable string executable_str = trace.python_callable() # TODO (matteochen): make the iters configurable - t, answer = compute_time_cost(executable_str, 1, *input_args) + t, answer = compute_time_cost(executable_str, iters, *input_args) reset_tracectx(trace_tok) From a9f0dbdff921e1d3740f18a6a76eab7129c5dc05 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 12 Jul 2024 16:45:30 +0300 Subject: [PATCH 012/171] Moved already present fusion regions count at FusionOperator class level / Incremented optimzier benchmark iters --- thunder/backend_optimizer/optimizer.py | 33 ++++++++++---------------- thunder/executors/nvfuserex_impl.py | 4 ++-- thunder/executors/torch_compile.py | 2 +- thunder/extend/__init__.py | 12 ++++++++++ 4 files changed, 27 insertions(+), 24 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 9372885988..7f6216a2b6 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -7,7 +7,7 @@ from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx from thunder.core.utils import check, safe_map_flat from thunder.executors.data_dependent_partition import Graph, Node -from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors +from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_all_executors, get_always_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Any, Hashable import thunder @@ -45,7 +45,10 @@ def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=T self.visualizer: Visualizer | None = visualizer self.partial_costs: dict[TraceCtx, float] = {} - self.log(f'New trace to optimize\n{self.trace}') + self.log(f'New trace to optimize:\n{self.trace}') + self.log('Executors:') + for o in self.executors: + print(f'{o.name} -> {type(o)}, is operator = {isinstance(o, OperatorExecutor)}, is fusion = {isinstance(o, FusionExecutor)}') class OptimizationStrat(Enum): EXAUSTIVE = 1 @@ -100,7 +103,7 @@ def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: li # Place the assigned symbols placed_t = self.place_optimizers(t, configuration) - cost, answer = benchmark_trace(placed_t) + cost, answer = benchmark_trace(placed_t, iters=10) del answer self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') @@ -154,7 +157,7 @@ def continue_search(): else: if min_cost_ex is None: raise AssertionError("Unexpected min cost executor or trace: None") - self.log(f'For id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}\n') + self.log(f'\nFor id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}\n') configuration.append(min_cost_ex) continue_search() @@ -176,28 +179,16 @@ def continue_search(): # try to fuse then and compare def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: - # Fuser call is expecting a fresh trace with no prior fusion calls hence the fusion regions count will start from 0. We need to override that logic - def count_fusion_regions(trace_in: TraceCtx) -> int: - count = 0 - for bsym in trace_in.bound_symbols: - if isinstance(bsym.sym.executor, FusionExecutor): - count += 1 - # ex.fuseion_pass regions are zero indexed - return max(0, count) - best_trace: TraceCtx = trace_in - best_time, answer = benchmark_trace(best_trace) + best_time, answer = benchmark_trace(best_trace, iters=10) del answer trace_in_time = best_time - fusion_regions = count_fusion_regions(trace_in) - self.log(f'Try to fuse. Fusion regions already present: {fusion_regions}') - for ex in self.fusion_executors: self.log(f'Try to fuse executor {ex.name} with trace:\n{trace_in}') - extrace = ex.fusion_pass(trace_in, fusion_regions) + extrace = ex.fusion_pass(trace_in) self.log(f'Fused trace:\n{extrace}') - extrace_time, answer = benchmark_trace(extrace) + extrace_time, answer = benchmark_trace(extrace, iters=10) del answer self.log(f'Fused trace time:{extrace_time/1000000} ms') @@ -470,7 +461,7 @@ def benchmark_traces(self): label = k trace = v - trace_time, res = benchmark_trace(trace) + trace_time, res = benchmark_trace(trace, iters=10) del res self.debug_msg += f'Trace name = [{label}] - Time = {trace_time / 1000000} ms\n{trace}\n\n' self.log(f'Benchmark trace "{label}" (time = {trace_time / 1000000} ms):\n{trace}') @@ -479,7 +470,7 @@ def benchmark_traces(self): optimal_trace = trace best_label = label - self.log(f'Benchmark end: Best trace "{best_label} (time = {min_run_time})":\n{optimal_trace}') + self.log(f'Benchmark end: Best trace "{best_label} (time = {min_run_time / 1000000} ms)":\n{optimal_trace}') self.optimal_trace = optimal_trace diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index 2774ff8bb4..d8ef99b26e 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -751,7 +751,7 @@ def map_redundant(x: Any) -> Any: return cse_trace # TODO Restore fusion logic here -- this just replaces supported operations in isolation at the moment - def fusion_pass(self, trace: TraceCtx, fusion_regions_in_trace: int = 0) -> TraceCtx: + def fusion_pass(self, trace: TraceCtx) -> TraceCtx: start_time_ns: int = time.perf_counter_ns() # Replace uniform with uniform_philox and rng state operators for better rematerialization from thunder.core.rematerialization import replace_uniform @@ -784,7 +784,7 @@ def _can_fuse_node(n: Node): # Counts how many fusions (per executor) have been constructed # (Used to name fusions like nvFusion0, nvFusion1, ...) - fusion_counter: int = fusion_regions_in_trace + fusion_counter: int = self.count_fusion_regions(trace, nvFuserExecutor) for bsyms in bound_symbol_groups: # TODO The following allows generating single node fusions, which # may be suboptimal for real-world performance. diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index ae7452fafa..fd5a753a25 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -156,7 +156,7 @@ def _can_fuse_node(n: Node): fused_bsyms = [] # Counts how many fusions (per executor) have been constructed - fusion_counter: int = 0 + fusion_counter: int = self.count_fusion_regions(trace, TorchCompileExecutor) for bsyms in bound_symbol_groups: if len(bsyms) == 1: bsym: BoundSymbol = bsyms[0] diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index e525c2e8a4..0135494de8 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -193,6 +193,18 @@ def _bind_postprocess(bsym: BoundSymbol) -> None: sym = Symbol(name=name, meta=_meta, is_fusion=True, _bind_postprocess=_bind_postprocess, executor=self) return sym.bind(*inputs, output=outputs) + # If a trace comes in with already placed fusion region we have to updated the initial counter (see derived class) + def count_fusion_regions(self, trace_in: TraceCtx, ex_type: type) -> int: + count = 0 + for bsym in trace_in.bound_symbols: + if not isinstance(bsym, BoundSymbol): + raise AssertionError(f"Expected a BoundSymbol, got: {type(bsym)}") + if type(bsym.sym.executor) is ex_type: + # if isinstance(bsym.sym.executor, FusionExecutor): + count += 1 + # ex.fuseion_pass regions are zero indexed + return max(0, count) + class OperatorExecutor(Executor): def __init__(self, name: Hashable, *, version: None | Any = None): From 3bfb69090d35c7f4d69945a47b4941e4ccaee70b Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Mon, 15 Jul 2024 10:50:17 +0300 Subject: [PATCH 013/171] Wip --- examples/dev/LLaMAMLP.py | 4 +- examples/dev/backward-log.out | 1305 -- examples/dev/backward_trc | 96 - examples/dev/backward_trc_final | 63 - examples/dev/backward_trc_fusion | 63 - examples/dev/forward_trc | 45 - examples/dev/forward_trc.dot | 4518 ---- examples/dev/litGPT.out | 25381 ----------------------- examples/dev/my_graph.png | Bin 144 -> 0 bytes examples/dev/simple.py | 22 +- examples/dev/simple_log.out | 132 - thunder/__init__.py | 3 +- thunder/backend_optimizer/optimizer.py | 2 +- thunder/executors/torch_autograd.py | 36 +- 14 files changed, 39 insertions(+), 31631 deletions(-) delete mode 100644 examples/dev/backward-log.out delete mode 100644 examples/dev/backward_trc delete mode 100644 examples/dev/backward_trc_final delete mode 100644 examples/dev/backward_trc_fusion delete mode 100644 examples/dev/forward_trc delete mode 100644 examples/dev/forward_trc.dot delete mode 100644 examples/dev/litGPT.out delete mode 100644 examples/dev/my_graph.png delete mode 100644 examples/dev/simple_log.out diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 15493f86ed..2a68d16e49 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -25,10 +25,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: tot_time = 0 iters = 12 for i in range(iters): - start = time.time_ns() + start = time.perf_counter_ns() ans = jmodel(x) torch.cuda.synchronize() - end = time.time_ns() + end = time.perf_counter_ns() # Skip the model without cache if i > 1: diff --git a/examples/dev/backward-log.out b/examples/dev/backward-log.out deleted file mode 100644 index f74f642d6f..0000000000 --- a/examples/dev/backward-log.out +++ /dev/null @@ -1,1305 +0,0 @@ -============================================ START: LABEL default -============================================ START: computation_trc split_forward_backward -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -============================================ END: computation_trc split_forward_backward -============================================ START: primal_trace sort_data_parallel_syncs -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -============================================ END: primal_trace sort_data_parallel_syncs -============================================ START: augmented_forward_pass -result -t8 -env -{'a': VJPDual(primal=t6, - residuals=((t0, False), None, ([], [t3], [], [t5], [t5, t0]))), - 'result': VJPDual(primal=t7, residuals=((t6, t1), None, ([t1, t6],))), - 't18': VJPDual(primal=t8, - residuals=((t7, t_proj_weight, None), - None, - ([t_proj_weight, t7],))), - 't_fc_1_weight': VJPDual(primal=t_fc_1_weight, residuals=()), - 't_fc_2_weight': VJPDual(primal=t_fc_2_weight, residuals=()), - 't_proj_weight': VJPDual(primal=t_proj_weight, residuals=()), - 'x': VJPDual(primal=x, residuals=()), - 'x_fc_1': VJPDual(primal=t0, - residuals=((x, t_fc_1_weight, None), - None, - ([t_fc_1_weight, x],))), - 'x_fc_2': VJPDual(primal=t1, - residuals=((x, t_fc_2_weight, None), - None, - ([t_fc_2_weight, x],)))} -============================================ END: augmented_forward_pass -============================================ START: primal_trace forward_and_backward_from_trace -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ END: primal_trace forward_and_backward_from_trace -============================================ START: before forward_trc transform_for_execution -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -============================================ END: after forward_trc transform_for_execution -============================================ START: LABEL forward_trc -============================================ START: LABEL backward_trc -============================================ START: before _transform_for_operator_executor_execution -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - t9, = cotangents - x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 - t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" - # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" - t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - t28 = ltorch.reshape(t9, -1, 4096) # t28: "cuda:0 f32[4096, 4096]" - # t28 = prims.reshape(t9, (4096, 4096)) # t28: "cuda:0 f32[4096, 4096]" - t29 = prims.transpose(t28, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" - # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" - t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - t32 = ltorch.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - t33 = ltorch.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - t34 = ltorch.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - t35 = ltorch.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - t36 = ltorch.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - t37 = ltorch.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - t38 = ltorch.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - t39 = ltorch.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - t40 = ltorch.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" - t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" - # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" - t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - t45 = ltorch.reshape(t33, -1, 11008) # t45: "cuda:0 f32[4096, 11008]" - # t45 = prims.reshape(t33, (4096, 11008)) # t45: "cuda:0 f32[4096, 11008]" - t46 = prims.transpose(t45, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" - # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" - t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" - # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" - t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - t52 = ltorch.reshape(t41, -1, 11008) # t52: "cuda:0 f32[4096, 11008]" - # t52 = prims.reshape(t41, (4096, 11008)) # t52: "cuda:0 f32[4096, 11008]" - t53 = prims.transpose(t52, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - t54 = ltorch.reshape(x, -1, 4096) # t54: "cuda:0 f32[4096, 4096]" - # t54 = prims.reshape(x, (4096, 4096)) # t54: "cuda:0 f32[4096, 4096]" - t55 = ltorch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" - # t55 = prims.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" - t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" - return (t56, t55, t48, t31) -============================================ END: before _transform_for_operator_executor_execution -============================================ START: after _transform_for_operator_executor_execution -# Constructed by Transform for operator executor execution (took 1 milliseconds) -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - t9, = cotangents - x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 - t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" - # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" - t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - t28 = ltorch.reshape(t9, -1, 4096) # t28: "cuda:0 f32[4096, 4096]" - # t28 = prims.reshape(t9, (4096, 4096)) # t28: "cuda:0 f32[4096, 4096]" - t29 = prims.transpose(t28, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" - # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" - t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - t32 = ltorch.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - t33 = ltorch.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - t34 = ltorch.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - t35 = ltorch.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - t36 = ltorch.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - t37 = ltorch.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - t38 = ltorch.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - t39 = ltorch.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - t40 = ltorch.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" - t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" - # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" - t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - t45 = ltorch.reshape(t33, -1, 11008) # t45: "cuda:0 f32[4096, 11008]" - # t45 = prims.reshape(t33, (4096, 11008)) # t45: "cuda:0 f32[4096, 11008]" - t46 = prims.transpose(t45, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" - # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" - t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" - # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" - t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - t52 = ltorch.reshape(t41, -1, 11008) # t52: "cuda:0 f32[4096, 11008]" - # t52 = prims.reshape(t41, (4096, 11008)) # t52: "cuda:0 f32[4096, 11008]" - t53 = prims.transpose(t52, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - t54 = ltorch.reshape(x, -1, 4096) # t54: "cuda:0 f32[4096, 4096]" - # t54 = prims.reshape(x, (4096, 4096)) # t54: "cuda:0 f32[4096, 4096]" - t55 = torch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" - # t55 = ltorch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" - # t55 = prims.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" - t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" - return (t56, t55, t48, t31) -============================================ GRAPH: _transform_for_operator_executor_execution -graph roots: 0, 1, -traversal nodes: -node ID 0 : [# saved_for_backward: "Collection"] - parents ids: - children ids: 2, -node ID 1 : [# cotangents: "Collection"] - parents ids: - children ids: 3, -node ID 2 : [C0, _, = saved_for_backward] - parents ids: 0, - children ids: 4, -node ID 3 : [t9, = cotangents] - parents ids: 1, - children ids: 8, 5, -node ID 4 : [x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0] - parents ids: 2, - children ids: 34, 6, 10, 12, 13, 14, 15, 17, 18, 19, 23, 27, 30, -node ID 8 : [t28 = ltorch.reshape(t9, -1, 4096) # t28: "cuda:0 f32[4096, 4096]" - # t28 = prims.reshape(t9, (4096, 4096)) # t28: "cuda:0 f32[4096, 4096]"] - parents ids: 3, - children ids: 9, -node ID 5 : [t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" - # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]"] - parents ids: 3, - children ids: 6, -node ID 34 : [t54 = ltorch.reshape(x, -1, 4096) # t54: "cuda:0 f32[4096, 4096]" - # t54 = prims.reshape(x, (4096, 4096)) # t54: "cuda:0 f32[4096, 4096]"] - parents ids: 4, - children ids: 35, -node ID 6 : [t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]"] - parents ids: 4, 5, - children ids: 7, -node ID 10 : [t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" - # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]"] - parents ids: 4, - children ids: 11, -node ID 12 : [t32 = ltorch.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 4, 7, - children ids: 14, 15, -node ID 13 : [t33 = ltorch.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 4, 7, - children ids: 25, 22, -node ID 14 : [t34 = ltorch.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 4, 12, - children ids: 21, -node ID 15 : [t35 = ltorch.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 4, 12, - children ids: 16, -node ID 17 : [t37 = ltorch.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 16, 4, - children ids: 18, -node ID 18 : [t38 = ltorch.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 17, 4, - children ids: 19, -node ID 19 : [t39 = ltorch.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 18, 4, - children ids: 20, -node ID 23 : [t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]"] - parents ids: 4, 22, - children ids: 24, -node ID 27 : [t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" - # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]"] - parents ids: 4, - children ids: 28, -node ID 30 : [t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]"] - parents ids: 4, 29, - children ids: 31, -node ID 9 : [t29 = prims.transpose(t28, (1, 0)) # t29: "cuda:0 f32[4096, 4096]"] - parents ids: 8, - children ids: 11, -node ID 35 : [t55 = ltorch.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]" - # t55 = prims.matmul(t53, t54) # t55: "cuda:0 f32[11008, 4096]"] - parents ids: 33, 34, - children ids: 37, -node ID 7 : [t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 6, - children ids: 12, 13, -node ID 11 : [t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]"] - parents ids: 9, 10, - children ids: 37, -node ID 25 : [t45 = ltorch.reshape(t33, -1, 11008) # t45: "cuda:0 f32[4096, 11008]" - # t45 = prims.reshape(t33, (4096, 11008)) # t45: "cuda:0 f32[4096, 11008]"] - parents ids: 13, - children ids: 26, -node ID 22 : [t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" - # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]"] - parents ids: 13, - children ids: 23, -node ID 21 : [t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 20, 14, - children ids: 32, 29, -node ID 16 : [t36 = ltorch.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 15, - children ids: 17, -node ID 20 : [t40 = ltorch.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 19, - children ids: 21, -node ID 24 : [t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 23, - children ids: 36, -node ID 28 : [t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]"] - parents ids: 26, 27, - children ids: 37, -node ID 31 : [t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 30, - children ids: 36, -node ID 37 : [return (t56, t55, t48, t31)] - parents ids: 11, 35, 36, 28, - children ids: -node ID 26 : [t46 = prims.transpose(t45, (1, 0)) # t46: "cuda:0 f32[11008, 4096]"] - parents ids: 25, - children ids: 28, -node ID 32 : [t52 = ltorch.reshape(t41, -1, 11008) # t52: "cuda:0 f32[4096, 11008]" - # t52 = prims.reshape(t41, (4096, 11008)) # t52: "cuda:0 f32[4096, 11008]"] - parents ids: 21, - children ids: 33, -node ID 29 : [t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" - # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]"] - parents ids: 21, - children ids: 30, -node ID 36 : [t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 24, 31, - children ids: 37, -node ID 33 : [t53 = prims.transpose(t52, (1, 0)) # t53: "cuda:0 f32[11008, 4096]"] - parents ids: 32, - children ids: 35, - -============================================ END: after _transform_for_operator_executor_execution -============================================ START: after fusion_pass -# Constructed by Fusion (took 3 milliseconds) -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - t9, = cotangents - x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 - t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" - # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" - t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" - # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" - t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" - # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" - t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - [t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) - # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" - t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" - # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" - t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" - # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" - t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - [t56] = nvFusion1(t44, t51) - # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" - return (t56, t55, t48, t31) -============================================ GRAPH: fusion_pass -graph roots: 0, 1, -traversal nodes: -node ID 0 : [# saved_for_backward: "Collection"] - parents ids: - children ids: 2, -node ID 1 : [# cotangents: "Collection"] - parents ids: - children ids: 3, -node ID 2 : [C0, _, = saved_for_backward] - parents ids: 0, - children ids: 4, -node ID 3 : [t9, = cotangents] - parents ids: 1, - children ids: 5, -node ID 4 : [x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0] - parents ids: 2, - children ids: 7, 8, 9, 12, 19, 20, -node ID 5 : [t25 = ltorch.reshape(t9, -1, 4096) # t25: "cuda:0 f32[4096, 4096]" - # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]"] - parents ids: 3, - children ids: 9, 6, -node ID 7 : [t30 = ltorch.reshape(t7, -1, 11008) # t30: "cuda:0 f32[4096, 11008]" - # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]"] - parents ids: 4, - children ids: 10, -node ID 8 : [t47 = ltorch.reshape(x, -1, 4096) # t47: "cuda:0 f32[4096, 4096]" - # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]"] - parents ids: 4, - children ids: 17, 18, -node ID 9 : [t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]"] - parents ids: 4, 5, - children ids: 11, -node ID 12 : [[t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) - # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 11, 4, - children ids: 13, 15, -node ID 19 : [t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]"] - parents ids: 4, 15, - children ids: 22, -node ID 20 : [t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]"] - parents ids: 4, 13, - children ids: 21, -node ID 6 : [t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]"] - parents ids: 5, - children ids: 10, -node ID 10 : [t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]"] - parents ids: 6, 7, - children ids: 24, -node ID 17 : [t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]"] - parents ids: 16, 8, - children ids: 24, -node ID 18 : [t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]"] - parents ids: 8, 14, - children ids: 24, -node ID 11 : [t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 9, - children ids: 12, -node ID 13 : [t42 = ltorch.reshape(t33, -1, 11008) # t42: "cuda:0 f32[4096, 11008]" - # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]"] - parents ids: 12, - children ids: 20, 14, -node ID 15 : [t49 = ltorch.reshape(t41, -1, 11008) # t49: "cuda:0 f32[4096, 11008]" - # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]"] - parents ids: 12, - children ids: 16, 19, -node ID 22 : [t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 19, - children ids: 23, -node ID 21 : [t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 20, - children ids: 23, -node ID 24 : [return (t56, t55, t48, t31)] - parents ids: 17, 18, 10, 23, - children ids: -node ID 14 : [t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]"] - parents ids: 13, - children ids: 18, -node ID 16 : [t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]"] - parents ids: 15, - children ids: 17, -node ID 23 : [[t56] = nvFusion1(t44, t51) - # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 21, 22, - children ids: 24, - -============================================ END: after fusion_pass -============================================ START: after _transform_for_operator_executor_execution (always) -# Constructed by Transform for operator executor execution (took 1 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - t9, = cotangents - x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0 - t25 = torch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" - # t25 = ltorch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" - # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]" - t29 = torch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - # t29 = ltorch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - # t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - t30 = torch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" - # t30 = ltorch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" - # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]" - t47 = torch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" - # t47 = ltorch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" - # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]" - t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - t27 = torch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - [t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) - # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]" - t42 = torch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" - # t42 = ltorch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" - # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]" - t46 = torch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - # t46 = ltorch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - # t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - t49 = torch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" - # t49 = ltorch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" - # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]" - t53 = torch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - # t53 = ltorch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - # t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - t44 = torch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - t51 = torch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - [t56] = nvFusion1(t44, t51) - # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]" - return (t56, t55, t48, t31) -============================================ GRAPH: fusion_pass -graph roots: 0, 1, -traversal nodes: -node ID 0 : [# saved_for_backward: "Collection"] - parents ids: - children ids: 2, -node ID 1 : [# cotangents: "Collection"] - parents ids: - children ids: 3, -node ID 2 : [C0, _, = saved_for_backward] - parents ids: 0, - children ids: 4, -node ID 3 : [t9, = cotangents] - parents ids: 1, - children ids: 5, -node ID 4 : [x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight, = C0] - parents ids: 2, - children ids: 7, 8, 9, 12, 19, 20, -node ID 5 : [t25 = torch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" - # t25 = ltorch.reshape(t9, (-1, 4096)) # t25: "cuda:0 f32[4096, 4096]" - # t25 = prims.reshape(t9, (4096, 4096)) # t25: "cuda:0 f32[4096, 4096]"] - parents ids: 3, - children ids: 9, 6, -node ID 7 : [t30 = torch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" - # t30 = ltorch.reshape(t7, (-1, 11008)) # t30: "cuda:0 f32[4096, 11008]" - # t30 = prims.reshape(t7, (4096, 11008)) # t30: "cuda:0 f32[4096, 11008]"] - parents ids: 4, - children ids: 10, -node ID 8 : [t47 = torch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" - # t47 = ltorch.reshape(x, (-1, 4096)) # t47: "cuda:0 f32[4096, 4096]" - # t47 = prims.reshape(x, (4096, 4096)) # t47: "cuda:0 f32[4096, 4096]"] - parents ids: 4, - children ids: 17, 18, -node ID 9 : [t26 = torch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = ltorch.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]" - # t26 = prims.matmul(t25, t_proj_weight) # t26: "cuda:0 f32[4096, 11008]"] - parents ids: 4, 5, - children ids: 11, -node ID 12 : [[t33, t41] = nvFusion0(t0, t1, t27, t3, t5, t6) - # t32 = prims.mul(t1, t27) # t32: "cuda:0 f32[2, 2048, 11008]" - # t33 = prims.mul(t6, t27) # t33: "cuda:0 f32[2, 2048, 11008]" - # t34 = prims.mul(t5, t32) # t34: "cuda:0 f32[2, 2048, 11008]" - # t35 = prims.mul(t0, t32) # t35: "cuda:0 f32[2, 2048, 11008]" - # t36 = prims.neg(t35) # t36: "cuda:0 f32[2, 2048, 11008]" - # t37 = prims.mul(t36, t5) # t37: "cuda:0 f32[2, 2048, 11008]" - # t38 = prims.mul(t37, t5) # t38: "cuda:0 f32[2, 2048, 11008]" - # t39 = prims.mul(t38, t3) # t39: "cuda:0 f32[2, 2048, 11008]" - # t40 = prims.neg(t39) # t40: "cuda:0 f32[2, 2048, 11008]" - # t41 = prims.add(t34, t40) # t41: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 11, 4, - children ids: 13, 15, -node ID 19 : [t50 = torch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = ltorch.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]" - # t50 = prims.matmul(t49, t_fc_1_weight) # t50: "cuda:0 f32[4096, 4096]"] - parents ids: 4, 15, - children ids: 22, -node ID 20 : [t43 = torch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = ltorch.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]" - # t43 = prims.matmul(t42, t_fc_2_weight) # t43: "cuda:0 f32[4096, 4096]"] - parents ids: 4, 13, - children ids: 21, -node ID 6 : [t29 = torch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - # t29 = ltorch.permute(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]" - # t29 = prims.transpose(t25, (1, 0)) # t29: "cuda:0 f32[4096, 4096]"] - parents ids: 5, - children ids: 10, -node ID 10 : [t31 = torch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = ltorch.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]" - # t31 = prims.matmul(t29, t30) # t31: "cuda:0 f32[4096, 11008]"] - parents ids: 6, 7, - children ids: 24, -node ID 17 : [t55 = torch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = ltorch.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]" - # t55 = prims.matmul(t53, t47) # t55: "cuda:0 f32[11008, 4096]"] - parents ids: 16, 8, - children ids: 24, -node ID 18 : [t48 = torch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = ltorch.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]" - # t48 = prims.matmul(t46, t47) # t48: "cuda:0 f32[11008, 4096]"] - parents ids: 8, 14, - children ids: 24, -node ID 11 : [t27 = torch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = ltorch.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]" - # t27 = prims.reshape(t26, (2, 2048, 11008)) # t27: "cuda:0 f32[2, 2048, 11008]"] - parents ids: 9, - children ids: 12, -node ID 13 : [t42 = torch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" - # t42 = ltorch.reshape(t33, (-1, 11008)) # t42: "cuda:0 f32[4096, 11008]" - # t42 = prims.reshape(t33, (4096, 11008)) # t42: "cuda:0 f32[4096, 11008]"] - parents ids: 12, - children ids: 20, 14, -node ID 15 : [t49 = torch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" - # t49 = ltorch.reshape(t41, (-1, 11008)) # t49: "cuda:0 f32[4096, 11008]" - # t49 = prims.reshape(t41, (4096, 11008)) # t49: "cuda:0 f32[4096, 11008]"] - parents ids: 12, - children ids: 16, 19, -node ID 22 : [t51 = torch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = ltorch.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]" - # t51 = prims.reshape(t50, (2, 2048, 4096)) # t51: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 19, - children ids: 23, -node ID 21 : [t44 = torch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = ltorch.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]" - # t44 = prims.reshape(t43, (2, 2048, 4096)) # t44: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 20, - children ids: 23, -node ID 24 : [return (t56, t55, t48, t31)] - parents ids: 17, 18, 10, 23, - children ids: -node ID 14 : [t46 = torch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - # t46 = ltorch.permute(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]" - # t46 = prims.transpose(t42, (1, 0)) # t46: "cuda:0 f32[11008, 4096]"] - parents ids: 13, - children ids: 18, -node ID 16 : [t53 = torch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - # t53 = ltorch.permute(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]" - # t53 = prims.transpose(t49, (1, 0)) # t53: "cuda:0 f32[11008, 4096]"] - parents ids: 15, - children ids: 17, -node ID 23 : [[t56] = nvFusion1(t44, t51) - # t56 = prims.add(t44, t51) # t56: "cuda:0 f32[2, 2048, 4096]"] - parents ids: 21, 22, - children ids: 24, - -============================================ END: after _transform_for_operator_executor_execution (always) ----------------------------------------------- all traces -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Transform for execution (took 2 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t3, t5, t6, t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Update Call Context (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## -# Constructed by Delete Last Used (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - # x_fc_1 = prims.linear(x, t_fc_1_weight, None) # x_fc_1: "cuda:0 f32[2, 2048, 11008]" - x_fc_2 = ltorch.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - # x_fc_2 = prims.linear(x, t_fc_2_weight, None) # x_fc_2: "cuda:0 f32[2, 2048, 11008]" - - # /workspace/pj/lightning-thunder/examples/dev/LLaMAMLP.py:13: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - a = ltorch.silu(x_fc_1, False) # a: "cuda:0 f32[2, 2048, 11008]" - # t9 = prims.neg(x_fc_1) # t9: "cuda:0 f32[2, 2048, 11008]" - # t10 = prims.exp(t9) # t10: "cuda:0 f32[2, 2048, 11008]" - # t11 = prims.add(1.0, t10) # t11: "cuda:0 f32[2, 2048, 11008]" - # t12 = prims.reciprocal(t11) # t12: "cuda:0 f32[2, 2048, 11008]" - # a = prims.mul(x_fc_1, t12) # a: "cuda:0 f32[2, 2048, 11008]" - result = ltorch.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - # result = prims.mul(a, x_fc_2) # result: "cuda:0 f32[2, 2048, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t18 = ltorch.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - # t18 = prims.linear(result, t_proj_weight, None) # t18: "cuda:0 f32[2, 2048, 4096]" - return t18 -############################################## -# Constructed by Augmented forward pass -import thunder -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - t4 = ltorch.add(1.0, t3, alpha=None) # t4: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - t6 = ltorch.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - t7 = ltorch.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Transform for execution (took 2 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t3, t5, t6, t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((x, t_fc_1_weight, t_fc_2_weight, t0, t3, t5, t6, t1, t7, t_proj_weight), ()) -############################################## -# Constructed by Update Call Context (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## -# Constructed by Delete Last Used (took 0 milliseconds) -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(x, t_fc_1_weight, t_fc_2_weight, t_proj_weight): - # x: "cuda:0 f32[2, 2048, 4096]" - # t_fc_1_weight: "cuda:0 f32[11008, 4096]" - # t_fc_2_weight: "cuda:0 f32[11008, 4096]" - # t_proj_weight: "cuda:0 f32[4096, 11008]" - t0 = torch.nn.functional.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = ltorch.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - # t0 = prims.linear(x, t_fc_1_weight, None) # t0: "cuda:0 f32[2, 2048, 11008]" - t1 = torch.nn.functional.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = ltorch.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - # t1 = prims.linear(x, t_fc_2_weight, None) # t1: "cuda:0 f32[2, 2048, 11008]" - [t7] = nvFusion0(t0, t1) - # t2 = prims.neg(t0) # t2: "cuda:0 f32[2, 2048, 11008]" - # t3 = prims.exp(t2) # t3: "cuda:0 f32[2, 2048, 11008]" - # t4 = prims.add(1.0, t3) # t4: "cuda:0 f32[2, 2048, 11008]" - # t5 = prims.reciprocal(t4) # t5: "cuda:0 f32[2, 2048, 11008]" - # t6 = prims.mul(t0, t5) # t6: "cuda:0 f32[2, 2048, 11008]" - # t7 = prims.mul(t6, t1) # t7: "cuda:0 f32[2, 2048, 11008]" - t8 = torch.nn.functional.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = ltorch.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - # t8 = prims.linear(t7, t_proj_weight, None) # t8: "cuda:0 f32[2, 2048, 4096]" - return {'output': t8, 'flat_args': [x, t_fc_1_weight, t_fc_2_weight, t_proj_weight], 'flat_output': (t8,)}, ((t0, t1, t7, t_fc_1_weight, t_fc_2_weight, t_proj_weight, x), ()) -############################################## ----------------------------------------------- ans -tensor([[[-0.0380, 0.1292, 0.0922, ..., 0.0574, -0.0760, -0.0142], - [-0.0312, -0.1352, 0.1404, ..., -0.0036, -0.1777, -0.0775], - [ 0.0121, -0.0281, -0.1634, ..., 0.0387, -0.2150, 0.0118], - ..., - [-0.1302, 0.0754, -0.1463, ..., -0.0835, -0.1263, 0.1630], - [-0.0158, 0.2085, 0.0153, ..., -0.0273, -0.0947, -0.0970], - [-0.2236, -0.1944, 0.0894, ..., 0.0347, -0.0962, 0.1017]], - - [[ 0.0363, -0.1088, 0.1518, ..., 0.0293, 0.1325, 0.0490], - [-0.1212, -0.2084, 0.1211, ..., -0.1555, 0.0875, -0.0580], - [ 0.1207, -0.0828, -0.0089, ..., 0.0490, 0.0931, 0.0576], - ..., - [-0.0100, 0.0776, 0.1118, ..., 0.0961, 0.0167, 0.0933], - [-0.1560, 0.0455, -0.0116, ..., 0.0028, -0.0157, -0.0022], - [ 0.3174, 0.0314, -0.0429, ..., 0.1140, 0.0264, 0.0614]]], - device='cuda:0', grad_fn=) diff --git a/examples/dev/backward_trc b/examples/dev/backward_trc deleted file mode 100644 index 84a36f90a6..0000000000 --- a/examples/dev/backward_trc +++ /dev/null @@ -1,96 +0,0 @@ -digraph { - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "2([Symbol name=unpack_sequence])" - "0([Symbol name=unpack_trivial])" -> "2([Symbol name=unpack_sequence])" - "3([Symbol name=unpack_sequence])" - "1([Symbol name=unpack_trivial])" -> "3([Symbol name=unpack_sequence])" - "4([Symbol name=unpack_sequence])" - "2([Symbol name=unpack_sequence])" -> "4([Symbol name=unpack_sequence])" - "8([Symbol name=reshape])" - "3([Symbol name=unpack_sequence])" -> "8([Symbol name=reshape])" - "5([Symbol name=reshape])" - "3([Symbol name=unpack_sequence])" -> "5([Symbol name=reshape])" - "34([Symbol name=reshape])" - "4([Symbol name=unpack_sequence])" -> "34([Symbol name=reshape])" - "6([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "6([Symbol name=matmul])" - "5([Symbol name=reshape])" -> "6([Symbol name=matmul])" - "10([Symbol name=reshape])" - "4([Symbol name=unpack_sequence])" -> "10([Symbol name=reshape])" - "12([Symbol name=mul])" - "4([Symbol name=unpack_sequence])" -> "12([Symbol name=mul])" - "7([Symbol name=reshape])" -> "12([Symbol name=mul])" - "13([Symbol name=mul])" - "4([Symbol name=unpack_sequence])" -> "13([Symbol name=mul])" - "7([Symbol name=reshape])" -> "13([Symbol name=mul])" - "14([Symbol name=mul])" - "4([Symbol name=unpack_sequence])" -> "14([Symbol name=mul])" - "12([Symbol name=mul])" -> "14([Symbol name=mul])" - "15([Symbol name=mul])" - "4([Symbol name=unpack_sequence])" -> "15([Symbol name=mul])" - "12([Symbol name=mul])" -> "15([Symbol name=mul])" - "17([Symbol name=mul])" - "16([Symbol name=neg])" -> "17([Symbol name=mul])" - "4([Symbol name=unpack_sequence])" -> "17([Symbol name=mul])" - "18([Symbol name=mul])" - "17([Symbol name=mul])" -> "18([Symbol name=mul])" - "4([Symbol name=unpack_sequence])" -> "18([Symbol name=mul])" - "19([Symbol name=mul])" - "18([Symbol name=mul])" -> "19([Symbol name=mul])" - "4([Symbol name=unpack_sequence])" -> "19([Symbol name=mul])" - "23([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "23([Symbol name=matmul])" - "22([Symbol name=reshape])" -> "23([Symbol name=matmul])" - "27([Symbol name=reshape])" - "4([Symbol name=unpack_sequence])" -> "27([Symbol name=reshape])" - "30([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "30([Symbol name=matmul])" - "29([Symbol name=reshape])" -> "30([Symbol name=matmul])" - "9([Symbol name=transpose])" - "8([Symbol name=reshape])" -> "9([Symbol name=transpose])" - "35([Symbol name=matmul])" - "33([Symbol name=transpose])" -> "35([Symbol name=matmul])" - "34([Symbol name=reshape])" -> "35([Symbol name=matmul])" - "7([Symbol name=reshape])" - "6([Symbol name=matmul])" -> "7([Symbol name=reshape])" - "11([Symbol name=matmul])" - "9([Symbol name=transpose])" -> "11([Symbol name=matmul])" - "10([Symbol name=reshape])" -> "11([Symbol name=matmul])" - "25([Symbol name=reshape])" - "13([Symbol name=mul])" -> "25([Symbol name=reshape])" - "22([Symbol name=reshape])" - "13([Symbol name=mul])" -> "22([Symbol name=reshape])" - "21([Symbol name=add])" - "20([Symbol name=neg])" -> "21([Symbol name=add])" - "14([Symbol name=mul])" -> "21([Symbol name=add])" - "16([Symbol name=neg])" - "15([Symbol name=mul])" -> "16([Symbol name=neg])" - "20([Symbol name=neg])" - "19([Symbol name=mul])" -> "20([Symbol name=neg])" - "24([Symbol name=reshape])" - "23([Symbol name=matmul])" -> "24([Symbol name=reshape])" - "28([Symbol name=matmul])" - "26([Symbol name=transpose])" -> "28([Symbol name=matmul])" - "27([Symbol name=reshape])" -> "28([Symbol name=matmul])" - "31([Symbol name=reshape])" - "30([Symbol name=matmul])" -> "31([Symbol name=reshape])" - "37([Symbol name=return])" - "11([Symbol name=matmul])" -> "37([Symbol name=return])" - "35([Symbol name=matmul])" -> "37([Symbol name=return])" - "36([Symbol name=add])" -> "37([Symbol name=return])" - "28([Symbol name=matmul])" -> "37([Symbol name=return])" - "26([Symbol name=transpose])" - "25([Symbol name=reshape])" -> "26([Symbol name=transpose])" - "32([Symbol name=reshape])" - "21([Symbol name=add])" -> "32([Symbol name=reshape])" - "29([Symbol name=reshape])" - "21([Symbol name=add])" -> "29([Symbol name=reshape])" - "36([Symbol name=add])" - "24([Symbol name=reshape])" -> "36([Symbol name=add])" - "31([Symbol name=reshape])" -> "36([Symbol name=add])" - "33([Symbol name=transpose])" - "32([Symbol name=reshape])" -> "33([Symbol name=transpose])" -} diff --git a/examples/dev/backward_trc_final b/examples/dev/backward_trc_final deleted file mode 100644 index a6d47c2df1..0000000000 --- a/examples/dev/backward_trc_final +++ /dev/null @@ -1,63 +0,0 @@ -digraph { - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "2([Symbol name=unpack_sequence])" - "0([Symbol name=unpack_trivial])" -> "2([Symbol name=unpack_sequence])" - "3([Symbol name=unpack_sequence])" - "1([Symbol name=unpack_trivial])" -> "3([Symbol name=unpack_sequence])" - "4([Symbol name=unpack_sequence])" - "2([Symbol name=unpack_sequence])" -> "4([Symbol name=unpack_sequence])" - "5([Symbol name=reshape])" - "3([Symbol name=unpack_sequence])" -> "5([Symbol name=reshape])" - "7([Symbol name=reshape])" - "4([Symbol name=unpack_sequence])" -> "7([Symbol name=reshape])" - "8([Symbol name=reshape])" - "4([Symbol name=unpack_sequence])" -> "8([Symbol name=reshape])" - "9([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "9([Symbol name=matmul])" - "5([Symbol name=reshape])" -> "9([Symbol name=matmul])" - "12([Symbol name=nvFusion0])" - "11([Symbol name=reshape])" -> "12([Symbol name=nvFusion0])" - "4([Symbol name=unpack_sequence])" -> "12([Symbol name=nvFusion0])" - "19([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "19([Symbol name=matmul])" - "15([Symbol name=reshape])" -> "19([Symbol name=matmul])" - "20([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "20([Symbol name=matmul])" - "13([Symbol name=reshape])" -> "20([Symbol name=matmul])" - "6([Symbol name=permute])" - "5([Symbol name=reshape])" -> "6([Symbol name=permute])" - "10([Symbol name=matmul])" - "6([Symbol name=permute])" -> "10([Symbol name=matmul])" - "7([Symbol name=reshape])" -> "10([Symbol name=matmul])" - "17([Symbol name=matmul])" - "16([Symbol name=permute])" -> "17([Symbol name=matmul])" - "8([Symbol name=reshape])" -> "17([Symbol name=matmul])" - "18([Symbol name=matmul])" - "8([Symbol name=reshape])" -> "18([Symbol name=matmul])" - "14([Symbol name=permute])" -> "18([Symbol name=matmul])" - "11([Symbol name=reshape])" - "9([Symbol name=matmul])" -> "11([Symbol name=reshape])" - "13([Symbol name=reshape])" - "12([Symbol name=nvFusion0])" -> "13([Symbol name=reshape])" - "15([Symbol name=reshape])" - "12([Symbol name=nvFusion0])" -> "15([Symbol name=reshape])" - "22([Symbol name=reshape])" - "19([Symbol name=matmul])" -> "22([Symbol name=reshape])" - "21([Symbol name=reshape])" - "20([Symbol name=matmul])" -> "21([Symbol name=reshape])" - "24([Symbol name=return])" - "17([Symbol name=matmul])" -> "24([Symbol name=return])" - "18([Symbol name=matmul])" -> "24([Symbol name=return])" - "10([Symbol name=matmul])" -> "24([Symbol name=return])" - "23([Symbol name=nvFusion1])" -> "24([Symbol name=return])" - "14([Symbol name=permute])" - "13([Symbol name=reshape])" -> "14([Symbol name=permute])" - "16([Symbol name=permute])" - "15([Symbol name=reshape])" -> "16([Symbol name=permute])" - "23([Symbol name=nvFusion1])" - "21([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" - "22([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" -} diff --git a/examples/dev/backward_trc_fusion b/examples/dev/backward_trc_fusion deleted file mode 100644 index 7f9f9df5b0..0000000000 --- a/examples/dev/backward_trc_fusion +++ /dev/null @@ -1,63 +0,0 @@ -digraph { - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "2([Symbol name=unpack_sequence])" - "0([Symbol name=unpack_trivial])" -> "2([Symbol name=unpack_sequence])" - "3([Symbol name=unpack_sequence])" - "1([Symbol name=unpack_trivial])" -> "3([Symbol name=unpack_sequence])" - "4([Symbol name=unpack_sequence])" - "2([Symbol name=unpack_sequence])" -> "4([Symbol name=unpack_sequence])" - "5([Symbol name=reshape])" - "3([Symbol name=unpack_sequence])" -> "5([Symbol name=reshape])" - "7([Symbol name=reshape])" - "4([Symbol name=unpack_sequence])" -> "7([Symbol name=reshape])" - "8([Symbol name=reshape])" - "4([Symbol name=unpack_sequence])" -> "8([Symbol name=reshape])" - "9([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "9([Symbol name=matmul])" - "5([Symbol name=reshape])" -> "9([Symbol name=matmul])" - "12([Symbol name=nvFusion0])" - "11([Symbol name=reshape])" -> "12([Symbol name=nvFusion0])" - "4([Symbol name=unpack_sequence])" -> "12([Symbol name=nvFusion0])" - "19([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "19([Symbol name=matmul])" - "15([Symbol name=reshape])" -> "19([Symbol name=matmul])" - "20([Symbol name=matmul])" - "4([Symbol name=unpack_sequence])" -> "20([Symbol name=matmul])" - "13([Symbol name=reshape])" -> "20([Symbol name=matmul])" - "6([Symbol name=transpose])" - "5([Symbol name=reshape])" -> "6([Symbol name=transpose])" - "10([Symbol name=matmul])" - "6([Symbol name=transpose])" -> "10([Symbol name=matmul])" - "7([Symbol name=reshape])" -> "10([Symbol name=matmul])" - "17([Symbol name=matmul])" - "16([Symbol name=transpose])" -> "17([Symbol name=matmul])" - "8([Symbol name=reshape])" -> "17([Symbol name=matmul])" - "18([Symbol name=matmul])" - "8([Symbol name=reshape])" -> "18([Symbol name=matmul])" - "14([Symbol name=transpose])" -> "18([Symbol name=matmul])" - "11([Symbol name=reshape])" - "9([Symbol name=matmul])" -> "11([Symbol name=reshape])" - "13([Symbol name=reshape])" - "12([Symbol name=nvFusion0])" -> "13([Symbol name=reshape])" - "15([Symbol name=reshape])" - "12([Symbol name=nvFusion0])" -> "15([Symbol name=reshape])" - "22([Symbol name=reshape])" - "19([Symbol name=matmul])" -> "22([Symbol name=reshape])" - "21([Symbol name=reshape])" - "20([Symbol name=matmul])" -> "21([Symbol name=reshape])" - "24([Symbol name=return])" - "17([Symbol name=matmul])" -> "24([Symbol name=return])" - "18([Symbol name=matmul])" -> "24([Symbol name=return])" - "10([Symbol name=matmul])" -> "24([Symbol name=return])" - "23([Symbol name=nvFusion1])" -> "24([Symbol name=return])" - "14([Symbol name=transpose])" - "13([Symbol name=reshape])" -> "14([Symbol name=transpose])" - "16([Symbol name=transpose])" - "15([Symbol name=reshape])" -> "16([Symbol name=transpose])" - "23([Symbol name=nvFusion1])" - "21([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" - "22([Symbol name=reshape])" -> "23([Symbol name=nvFusion1])" -} diff --git a/examples/dev/forward_trc b/examples/dev/forward_trc deleted file mode 100644 index 33b077e8ec..0000000000 --- a/examples/dev/forward_trc +++ /dev/null @@ -1,45 +0,0 @@ -digraph { - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "2([Symbol name=unpack_trivial])" - "3([Symbol name=unpack_trivial])" - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "2([Symbol name=unpack_trivial])" - "3([Symbol name=unpack_trivial])" - "13([Symbol name=return])" - "0([Symbol name=unpack_trivial])" -> "13([Symbol name=return])" - "1([Symbol name=unpack_trivial])" -> "13([Symbol name=return])" - "2([Symbol name=unpack_trivial])" -> "13([Symbol name=return])" - "3([Symbol name=unpack_trivial])" -> "13([Symbol name=return])" - "4([Symbol name=linear])" -> "13([Symbol name=return])" - "5([Symbol name=linear])" -> "13([Symbol name=return])" - "7([Symbol name=exp])" -> "13([Symbol name=return])" - "9([Symbol name=reciprocal])" -> "13([Symbol name=return])" - "10([Symbol name=mul])" -> "13([Symbol name=return])" - "11([Symbol name=mul])" -> "13([Symbol name=return])" - "12([Symbol name=linear])" -> "13([Symbol name=return])" - "4([Symbol name=linear])" - "0([Symbol name=unpack_trivial])" -> "4([Symbol name=linear])" - "1([Symbol name=unpack_trivial])" -> "4([Symbol name=linear])" - "5([Symbol name=linear])" - "0([Symbol name=unpack_trivial])" -> "5([Symbol name=linear])" - "2([Symbol name=unpack_trivial])" -> "5([Symbol name=linear])" - "12([Symbol name=linear])" - "3([Symbol name=unpack_trivial])" -> "12([Symbol name=linear])" - "11([Symbol name=mul])" -> "12([Symbol name=linear])" - "10([Symbol name=mul])" - "9([Symbol name=reciprocal])" -> "10([Symbol name=mul])" - "4([Symbol name=linear])" -> "10([Symbol name=mul])" - "6([Symbol name=neg])" - "4([Symbol name=linear])" -> "6([Symbol name=neg])" - "11([Symbol name=mul])" - "10([Symbol name=mul])" -> "11([Symbol name=mul])" - "5([Symbol name=linear])" -> "11([Symbol name=mul])" - "7([Symbol name=exp])" - "6([Symbol name=neg])" -> "7([Symbol name=exp])" - "8([Symbol name=add])" - "7([Symbol name=exp])" -> "8([Symbol name=add])" - "9([Symbol name=reciprocal])" - "8([Symbol name=add])" -> "9([Symbol name=reciprocal])" -} diff --git a/examples/dev/forward_trc.dot b/examples/dev/forward_trc.dot deleted file mode 100644 index 866907bfe1..0000000000 --- a/examples/dev/forward_trc.dot +++ /dev/null @@ -1,4518 +0,0 @@ -digraph { - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "2([Symbol name=unpack_trivial])" - "3([Symbol name=unpack_trivial])" - "4([Symbol name=unpack_trivial])" - "5([Symbol name=unpack_trivial])" - "6([Symbol name=unpack_trivial])" - "7([Symbol name=unpack_trivial])" - "8([Symbol name=unpack_trivial])" - "9([Symbol name=unpack_trivial])" - "10([Symbol name=unpack_trivial])" - "11([Symbol name=unpack_trivial])" - "12([Symbol name=unpack_trivial])" - "13([Symbol name=unpack_trivial])" - "14([Symbol name=unpack_trivial])" - "15([Symbol name=unpack_trivial])" - "16([Symbol name=unpack_trivial])" - "17([Symbol name=unpack_trivial])" - "18([Symbol name=unpack_trivial])" - "19([Symbol name=unpack_trivial])" - "20([Symbol name=unpack_trivial])" - "21([Symbol name=unpack_trivial])" - "22([Symbol name=unpack_trivial])" - "23([Symbol name=unpack_trivial])" - "24([Symbol name=unpack_trivial])" - "25([Symbol name=unpack_trivial])" - "26([Symbol name=unpack_trivial])" - "27([Symbol name=unpack_trivial])" - "28([Symbol name=unpack_trivial])" - "29([Symbol name=unpack_trivial])" - "30([Symbol name=unpack_trivial])" - "31([Symbol name=unpack_trivial])" - "32([Symbol name=unpack_trivial])" - "33([Symbol name=unpack_trivial])" - "34([Symbol name=unpack_trivial])" - "35([Symbol name=unpack_trivial])" - "36([Symbol name=unpack_trivial])" - "37([Symbol name=unpack_trivial])" - "38([Symbol name=unpack_trivial])" - "39([Symbol name=unpack_trivial])" - "40([Symbol name=unpack_trivial])" - "41([Symbol name=unpack_trivial])" - "42([Symbol name=unpack_trivial])" - "43([Symbol name=unpack_trivial])" - "44([Symbol name=unpack_trivial])" - "45([Symbol name=unpack_trivial])" - "46([Symbol name=unpack_trivial])" - "47([Symbol name=unpack_trivial])" - "48([Symbol name=unpack_trivial])" - "49([Symbol name=unpack_trivial])" - "50([Symbol name=unpack_trivial])" - "51([Symbol name=unpack_trivial])" - "52([Symbol name=unpack_trivial])" - "53([Symbol name=unpack_trivial])" - "54([Symbol name=unpack_trivial])" - "55([Symbol name=unpack_trivial])" - "56([Symbol name=unpack_trivial])" - "57([Symbol name=unpack_trivial])" - "58([Symbol name=unpack_trivial])" - "59([Symbol name=unpack_trivial])" - "60([Symbol name=unpack_trivial])" - "61([Symbol name=unpack_trivial])" - "62([Symbol name=unpack_trivial])" - "63([Symbol name=unpack_trivial])" - "64([Symbol name=unpack_trivial])" - "65([Symbol name=unpack_trivial])" - "66([Symbol name=unpack_trivial])" - "67([Symbol name=unpack_trivial])" - "68([Symbol name=unpack_trivial])" - "69([Symbol name=unpack_trivial])" - "70([Symbol name=unpack_trivial])" - "71([Symbol name=unpack_trivial])" - "72([Symbol name=unpack_trivial])" - "73([Symbol name=unpack_trivial])" - "74([Symbol name=unpack_trivial])" - "75([Symbol name=unpack_trivial])" - "76([Symbol name=unpack_trivial])" - "77([Symbol name=unpack_trivial])" - "78([Symbol name=unpack_trivial])" - "79([Symbol name=unpack_trivial])" - "80([Symbol name=unpack_trivial])" - "81([Symbol name=unpack_trivial])" - "82([Symbol name=unpack_trivial])" - "83([Symbol name=unpack_trivial])" - "84([Symbol name=unpack_trivial])" - "85([Symbol name=unpack_trivial])" - "86([Symbol name=unpack_trivial])" - "87([Symbol name=unpack_trivial])" - "88([Symbol name=unpack_trivial])" - "89([Symbol name=unpack_trivial])" - "90([Symbol name=unpack_trivial])" - "91([Symbol name=unpack_trivial])" - "92([Symbol name=unpack_trivial])" - "93([Symbol name=unpack_trivial])" - "94([Symbol name=unpack_trivial])" - "95([Symbol name=unpack_trivial])" - "96([Symbol name=unpack_trivial])" - "97([Symbol name=unpack_trivial])" - "98([Symbol name=unpack_trivial])" - "99([Symbol name=unpack_trivial])" - "100([Symbol name=unpack_trivial])" - "101([Symbol name=unpack_trivial])" - "102([Symbol name=unpack_trivial])" - "103([Symbol name=unpack_trivial])" - "104([Symbol name=unpack_trivial])" - "105([Symbol name=unpack_trivial])" - "106([Symbol name=unpack_trivial])" - "107([Symbol name=unpack_trivial])" - "108([Symbol name=unpack_trivial])" - "109([Symbol name=unpack_trivial])" - "110([Symbol name=unpack_trivial])" - "111([Symbol name=unpack_trivial])" - "112([Symbol name=unpack_trivial])" - "113([Symbol name=unpack_trivial])" - "114([Symbol name=unpack_trivial])" - "115([Symbol name=unpack_trivial])" - "116([Symbol name=unpack_trivial])" - "117([Symbol name=unpack_trivial])" - "0([Symbol name=unpack_trivial])" - "1([Symbol name=unpack_trivial])" - "2([Symbol name=unpack_trivial])" - "3([Symbol name=unpack_trivial])" - "4([Symbol name=unpack_trivial])" - "5([Symbol name=unpack_trivial])" - "6([Symbol name=unpack_trivial])" - "7([Symbol name=unpack_trivial])" - "8([Symbol name=unpack_trivial])" - "9([Symbol name=unpack_trivial])" - "10([Symbol name=unpack_trivial])" - "11([Symbol name=unpack_trivial])" - "12([Symbol name=unpack_trivial])" - "13([Symbol name=unpack_trivial])" - "14([Symbol name=unpack_trivial])" - "15([Symbol name=unpack_trivial])" - "16([Symbol name=unpack_trivial])" - "17([Symbol name=unpack_trivial])" - "18([Symbol name=unpack_trivial])" - "19([Symbol name=unpack_trivial])" - "20([Symbol name=unpack_trivial])" - "21([Symbol name=unpack_trivial])" - "22([Symbol name=unpack_trivial])" - "23([Symbol name=unpack_trivial])" - "24([Symbol name=unpack_trivial])" - "25([Symbol name=unpack_trivial])" - "26([Symbol name=unpack_trivial])" - "27([Symbol name=unpack_trivial])" - "28([Symbol name=unpack_trivial])" - "29([Symbol name=unpack_trivial])" - "30([Symbol name=unpack_trivial])" - "31([Symbol name=unpack_trivial])" - "32([Symbol name=unpack_trivial])" - "33([Symbol name=unpack_trivial])" - "34([Symbol name=unpack_trivial])" - "35([Symbol name=unpack_trivial])" - "36([Symbol name=unpack_trivial])" - "37([Symbol name=unpack_trivial])" - "38([Symbol name=unpack_trivial])" - "39([Symbol name=unpack_trivial])" - "40([Symbol name=unpack_trivial])" - "41([Symbol name=unpack_trivial])" - "42([Symbol name=unpack_trivial])" - "43([Symbol name=unpack_trivial])" - "44([Symbol name=unpack_trivial])" - "45([Symbol name=unpack_trivial])" - "46([Symbol name=unpack_trivial])" - "47([Symbol name=unpack_trivial])" - "48([Symbol name=unpack_trivial])" - "49([Symbol name=unpack_trivial])" - "50([Symbol name=unpack_trivial])" - "51([Symbol name=unpack_trivial])" - "52([Symbol name=unpack_trivial])" - "53([Symbol name=unpack_trivial])" - "54([Symbol name=unpack_trivial])" - "55([Symbol name=unpack_trivial])" - "56([Symbol name=unpack_trivial])" - "57([Symbol name=unpack_trivial])" - "58([Symbol name=unpack_trivial])" - "59([Symbol name=unpack_trivial])" - "60([Symbol name=unpack_trivial])" - "61([Symbol name=unpack_trivial])" - "62([Symbol name=unpack_trivial])" - "63([Symbol name=unpack_trivial])" - "64([Symbol name=unpack_trivial])" - "65([Symbol name=unpack_trivial])" - "66([Symbol name=unpack_trivial])" - "67([Symbol name=unpack_trivial])" - "68([Symbol name=unpack_trivial])" - "69([Symbol name=unpack_trivial])" - "70([Symbol name=unpack_trivial])" - "71([Symbol name=unpack_trivial])" - "72([Symbol name=unpack_trivial])" - "73([Symbol name=unpack_trivial])" - "74([Symbol name=unpack_trivial])" - "75([Symbol name=unpack_trivial])" - "76([Symbol name=unpack_trivial])" - "77([Symbol name=unpack_trivial])" - "78([Symbol name=unpack_trivial])" - "79([Symbol name=unpack_trivial])" - "80([Symbol name=unpack_trivial])" - "81([Symbol name=unpack_trivial])" - "82([Symbol name=unpack_trivial])" - "83([Symbol name=unpack_trivial])" - "84([Symbol name=unpack_trivial])" - "85([Symbol name=unpack_trivial])" - "86([Symbol name=unpack_trivial])" - "87([Symbol name=unpack_trivial])" - "88([Symbol name=unpack_trivial])" - "89([Symbol name=unpack_trivial])" - "90([Symbol name=unpack_trivial])" - "91([Symbol name=unpack_trivial])" - "92([Symbol name=unpack_trivial])" - "93([Symbol name=unpack_trivial])" - "94([Symbol name=unpack_trivial])" - "95([Symbol name=unpack_trivial])" - "96([Symbol name=unpack_trivial])" - "97([Symbol name=unpack_trivial])" - "98([Symbol name=unpack_trivial])" - "99([Symbol name=unpack_trivial])" - "100([Symbol name=unpack_trivial])" - "101([Symbol name=unpack_trivial])" - "102([Symbol name=unpack_trivial])" - "103([Symbol name=unpack_trivial])" - "104([Symbol name=unpack_trivial])" - "105([Symbol name=unpack_trivial])" - "106([Symbol name=unpack_trivial])" - "107([Symbol name=unpack_trivial])" - "108([Symbol name=unpack_trivial])" - "109([Symbol name=unpack_trivial])" - "110([Symbol name=unpack_trivial])" - "111([Symbol name=unpack_trivial])" - "112([Symbol name=unpack_trivial])" - "113([Symbol name=unpack_trivial])" - "114([Symbol name=unpack_trivial])" - "115([Symbol name=unpack_trivial])" - "116([Symbol name=unpack_trivial])" - "117([Symbol name=unpack_trivial])" - "120([Symbol name=embedding])" - "0([Symbol name=unpack_trivial])" -> "120([Symbol name=embedding])" - "117([Symbol name=unpack_trivial])" -> "120([Symbol name=embedding])" - "1737([Symbol name=return])" - "0([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "1([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "2([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "3([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "4([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "5([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "6([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "7([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "8([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "9([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "10([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "11([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "12([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "13([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "14([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "15([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "16([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "17([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "18([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "19([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "20([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "21([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "22([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "23([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "24([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "25([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "26([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "27([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "28([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "29([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "30([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "31([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "32([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "33([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "34([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "35([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "36([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "37([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "38([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "39([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "40([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "41([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "42([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "43([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "44([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "45([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "46([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "47([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "48([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "49([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "50([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "51([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "52([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "53([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "54([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "55([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "56([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "57([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "58([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "59([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "60([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "61([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "62([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "63([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "64([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "65([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "66([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "67([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "68([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "69([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "70([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "71([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "72([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "73([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "74([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "75([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "76([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "77([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "78([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "79([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "80([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "81([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "82([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "83([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "84([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "85([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "86([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "87([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "88([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "89([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "90([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "91([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "92([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "93([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "94([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "95([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "96([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "97([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "98([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "99([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "100([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "101([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "102([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "103([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "104([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "105([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "106([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "107([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "108([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "109([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "110([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "111([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "112([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "113([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "114([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "115([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "116([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "117([Symbol name=unpack_trivial])" -> "1737([Symbol name=return])" - "121([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "127([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "128([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "132([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "133([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "135([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "142([Symbol name=reshape])" -> "1737([Symbol name=return])" - "150([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "151([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "153([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "154([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "165([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "166([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "168([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "169([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "174([Symbol name=cat])" -> "1737([Symbol name=return])" - "176([Symbol name=cat])" -> "1737([Symbol name=return])" - "177([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "179([Symbol name=reshape])" -> "1737([Symbol name=return])" - "185([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "191([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "192([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "196([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "197([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "199([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "204([Symbol name=exp])" -> "1737([Symbol name=return])" - "206([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "208([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "209([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "212([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "213([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "215([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "221([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "227([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "228([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "232([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "233([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "235([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "242([Symbol name=reshape])" -> "1737([Symbol name=return])" - "250([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "251([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "253([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "254([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "265([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "266([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "268([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "269([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "274([Symbol name=cat])" -> "1737([Symbol name=return])" - "276([Symbol name=cat])" -> "1737([Symbol name=return])" - "277([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "279([Symbol name=reshape])" -> "1737([Symbol name=return])" - "285([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "291([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "292([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "296([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "297([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "299([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "304([Symbol name=exp])" -> "1737([Symbol name=return])" - "306([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "308([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "309([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "312([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "313([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "315([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "321([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "327([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "328([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "332([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "333([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "335([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "342([Symbol name=reshape])" -> "1737([Symbol name=return])" - "350([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "351([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "353([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "354([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "365([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "366([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "368([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "369([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "374([Symbol name=cat])" -> "1737([Symbol name=return])" - "376([Symbol name=cat])" -> "1737([Symbol name=return])" - "377([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "379([Symbol name=reshape])" -> "1737([Symbol name=return])" - "385([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "391([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "392([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "396([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "397([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "399([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "404([Symbol name=exp])" -> "1737([Symbol name=return])" - "406([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "408([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "409([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "412([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "413([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "415([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "421([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "427([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "428([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "432([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "433([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "435([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "442([Symbol name=reshape])" -> "1737([Symbol name=return])" - "450([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "451([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "453([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "454([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "465([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "466([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "468([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "469([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "474([Symbol name=cat])" -> "1737([Symbol name=return])" - "476([Symbol name=cat])" -> "1737([Symbol name=return])" - "477([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "479([Symbol name=reshape])" -> "1737([Symbol name=return])" - "485([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "491([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "492([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "496([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "497([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "499([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "504([Symbol name=exp])" -> "1737([Symbol name=return])" - "506([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "508([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "509([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "512([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "513([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "515([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "521([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "527([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "528([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "532([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "533([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "535([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "542([Symbol name=reshape])" -> "1737([Symbol name=return])" - "550([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "551([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "553([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "554([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "565([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "566([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "568([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "569([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "574([Symbol name=cat])" -> "1737([Symbol name=return])" - "576([Symbol name=cat])" -> "1737([Symbol name=return])" - "577([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "579([Symbol name=reshape])" -> "1737([Symbol name=return])" - "585([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "591([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "592([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "596([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "597([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "599([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "604([Symbol name=exp])" -> "1737([Symbol name=return])" - "606([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "608([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "609([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "612([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "613([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "615([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "621([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "627([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "628([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "632([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "633([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "635([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "642([Symbol name=reshape])" -> "1737([Symbol name=return])" - "650([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "651([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "653([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "654([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "665([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "666([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "668([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "669([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "674([Symbol name=cat])" -> "1737([Symbol name=return])" - "676([Symbol name=cat])" -> "1737([Symbol name=return])" - "677([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "679([Symbol name=reshape])" -> "1737([Symbol name=return])" - "685([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "691([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "692([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "696([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "697([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "699([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "704([Symbol name=exp])" -> "1737([Symbol name=return])" - "706([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "708([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "709([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "712([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "713([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "715([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "721([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "727([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "728([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "732([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "733([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "735([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "742([Symbol name=reshape])" -> "1737([Symbol name=return])" - "750([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "751([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "753([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "754([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "765([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "766([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "768([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "769([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "774([Symbol name=cat])" -> "1737([Symbol name=return])" - "776([Symbol name=cat])" -> "1737([Symbol name=return])" - "777([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "779([Symbol name=reshape])" -> "1737([Symbol name=return])" - "785([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "791([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "792([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "796([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "797([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "799([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "804([Symbol name=exp])" -> "1737([Symbol name=return])" - "806([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "808([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "809([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "812([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "813([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "815([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "821([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "827([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "828([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "832([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "833([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "835([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "842([Symbol name=reshape])" -> "1737([Symbol name=return])" - "850([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "851([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "853([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "854([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "865([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "866([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "868([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "869([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "874([Symbol name=cat])" -> "1737([Symbol name=return])" - "876([Symbol name=cat])" -> "1737([Symbol name=return])" - "877([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "879([Symbol name=reshape])" -> "1737([Symbol name=return])" - "885([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "891([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "892([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "896([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "897([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "899([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "904([Symbol name=exp])" -> "1737([Symbol name=return])" - "906([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "908([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "909([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "912([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "913([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "915([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "921([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "927([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "928([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "932([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "933([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "935([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "942([Symbol name=reshape])" -> "1737([Symbol name=return])" - "950([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "951([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "953([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "954([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "965([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "966([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "968([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "969([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "974([Symbol name=cat])" -> "1737([Symbol name=return])" - "976([Symbol name=cat])" -> "1737([Symbol name=return])" - "977([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "979([Symbol name=reshape])" -> "1737([Symbol name=return])" - "985([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "991([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "992([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "996([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "997([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "999([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1004([Symbol name=exp])" -> "1737([Symbol name=return])" - "1006([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1008([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1009([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1012([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1013([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1015([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1021([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1027([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1028([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1032([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1033([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1035([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1042([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1050([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1051([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1053([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1054([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1065([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1066([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1068([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1069([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1074([Symbol name=cat])" -> "1737([Symbol name=return])" - "1076([Symbol name=cat])" -> "1737([Symbol name=return])" - "1077([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "1079([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1085([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1091([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1092([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1096([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1097([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1099([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1104([Symbol name=exp])" -> "1737([Symbol name=return])" - "1106([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1108([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1109([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1112([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1113([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1115([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1121([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1127([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1128([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1132([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1133([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1135([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1142([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1150([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1151([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1153([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1154([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1165([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1166([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1168([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1169([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1174([Symbol name=cat])" -> "1737([Symbol name=return])" - "1176([Symbol name=cat])" -> "1737([Symbol name=return])" - "1177([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "1179([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1185([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1191([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1192([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1196([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1197([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1199([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1204([Symbol name=exp])" -> "1737([Symbol name=return])" - "1206([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1208([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1209([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1212([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1213([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1215([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1221([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1227([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1228([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1232([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1233([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1235([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1242([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1250([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1251([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1253([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1254([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1265([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1266([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1268([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1269([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1274([Symbol name=cat])" -> "1737([Symbol name=return])" - "1276([Symbol name=cat])" -> "1737([Symbol name=return])" - "1277([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "1279([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1285([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1291([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1292([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1296([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1297([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1299([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1304([Symbol name=exp])" -> "1737([Symbol name=return])" - "1306([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1308([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1309([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1312([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1313([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1315([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1321([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1327([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1328([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1332([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1333([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1335([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1342([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1350([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1351([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1353([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1354([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1365([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1366([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1368([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1369([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1374([Symbol name=cat])" -> "1737([Symbol name=return])" - "1376([Symbol name=cat])" -> "1737([Symbol name=return])" - "1377([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "1379([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1385([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1391([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1392([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1396([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1397([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1399([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1404([Symbol name=exp])" -> "1737([Symbol name=return])" - "1406([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1408([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1409([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1412([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1413([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1415([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1421([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1427([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1428([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1432([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1433([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1435([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1442([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1450([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1451([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1453([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1454([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1465([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1466([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1468([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1469([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1474([Symbol name=cat])" -> "1737([Symbol name=return])" - "1476([Symbol name=cat])" -> "1737([Symbol name=return])" - "1477([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "1479([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1485([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1491([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1492([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1496([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1497([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1499([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1504([Symbol name=exp])" -> "1737([Symbol name=return])" - "1506([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1508([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1509([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1512([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1513([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1515([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1521([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1527([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1528([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1532([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1533([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1535([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1542([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1550([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1551([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1553([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1554([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1565([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1566([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1568([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1569([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1574([Symbol name=cat])" -> "1737([Symbol name=return])" - "1576([Symbol name=cat])" -> "1737([Symbol name=return])" - "1577([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "1579([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1585([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1591([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1592([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1596([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1597([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1599([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1604([Symbol name=exp])" -> "1737([Symbol name=return])" - "1606([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1608([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1609([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1612([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1613([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1615([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1621([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1627([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1628([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1632([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1633([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1635([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1642([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1650([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1651([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1653([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1654([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1665([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1666([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1668([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1669([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1674([Symbol name=cat])" -> "1737([Symbol name=return])" - "1676([Symbol name=cat])" -> "1737([Symbol name=return])" - "1677([Symbol name=cudnn_sdpa_fwd])" -> "1737([Symbol name=return])" - "1679([Symbol name=reshape])" -> "1737([Symbol name=return])" - "1685([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1691([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1692([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1696([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1697([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1699([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1704([Symbol name=exp])" -> "1737([Symbol name=return])" - "1706([Symbol name=reciprocal])" -> "1737([Symbol name=return])" - "1708([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1709([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1712([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1713([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1715([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1721([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1727([Symbol name=rsqrt])" -> "1737([Symbol name=return])" - "1728([Symbol name=broadcast_in_dim])" -> "1737([Symbol name=return])" - "1732([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1733([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1735([Symbol name=convert_element_type])" -> "1737([Symbol name=return])" - "1736([Symbol name=linear])" -> "1737([Symbol name=return])" - "118([Symbol name=slice_prim])" - "1([Symbol name=unpack_trivial])" -> "118([Symbol name=slice_prim])" - "1736([Symbol name=linear])" - "2([Symbol name=unpack_trivial])" -> "1736([Symbol name=linear])" - "1735([Symbol name=convert_element_type])" -> "1736([Symbol name=linear])" - "119([Symbol name=slice_prim])" - "3([Symbol name=unpack_trivial])" -> "119([Symbol name=slice_prim])" - "136([Symbol name=linear])" - "4([Symbol name=unpack_trivial])" -> "136([Symbol name=linear])" - "135([Symbol name=convert_element_type])" -> "136([Symbol name=linear])" - "180([Symbol name=linear])" - "179([Symbol name=reshape])" -> "180([Symbol name=linear])" - "5([Symbol name=unpack_trivial])" -> "180([Symbol name=linear])" - "200([Symbol name=linear])" - "6([Symbol name=unpack_trivial])" -> "200([Symbol name=linear])" - "199([Symbol name=convert_element_type])" -> "200([Symbol name=linear])" - "201([Symbol name=linear])" - "7([Symbol name=unpack_trivial])" -> "201([Symbol name=linear])" - "199([Symbol name=convert_element_type])" -> "201([Symbol name=linear])" - "216([Symbol name=linear])" - "8([Symbol name=unpack_trivial])" -> "216([Symbol name=linear])" - "215([Symbol name=convert_element_type])" -> "216([Symbol name=linear])" - "131([Symbol name=broadcast_in_dim])" - "9([Symbol name=unpack_trivial])" -> "131([Symbol name=broadcast_in_dim])" - "195([Symbol name=broadcast_in_dim])" - "10([Symbol name=unpack_trivial])" -> "195([Symbol name=broadcast_in_dim])" - "236([Symbol name=linear])" - "11([Symbol name=unpack_trivial])" -> "236([Symbol name=linear])" - "235([Symbol name=convert_element_type])" -> "236([Symbol name=linear])" - "280([Symbol name=linear])" - "12([Symbol name=unpack_trivial])" -> "280([Symbol name=linear])" - "279([Symbol name=reshape])" -> "280([Symbol name=linear])" - "300([Symbol name=linear])" - "299([Symbol name=convert_element_type])" -> "300([Symbol name=linear])" - "13([Symbol name=unpack_trivial])" -> "300([Symbol name=linear])" - "301([Symbol name=linear])" - "299([Symbol name=convert_element_type])" -> "301([Symbol name=linear])" - "14([Symbol name=unpack_trivial])" -> "301([Symbol name=linear])" - "316([Symbol name=linear])" - "315([Symbol name=convert_element_type])" -> "316([Symbol name=linear])" - "15([Symbol name=unpack_trivial])" -> "316([Symbol name=linear])" - "231([Symbol name=broadcast_in_dim])" - "16([Symbol name=unpack_trivial])" -> "231([Symbol name=broadcast_in_dim])" - "295([Symbol name=broadcast_in_dim])" - "17([Symbol name=unpack_trivial])" -> "295([Symbol name=broadcast_in_dim])" - "336([Symbol name=linear])" - "18([Symbol name=unpack_trivial])" -> "336([Symbol name=linear])" - "335([Symbol name=convert_element_type])" -> "336([Symbol name=linear])" - "380([Symbol name=linear])" - "19([Symbol name=unpack_trivial])" -> "380([Symbol name=linear])" - "379([Symbol name=reshape])" -> "380([Symbol name=linear])" - "400([Symbol name=linear])" - "20([Symbol name=unpack_trivial])" -> "400([Symbol name=linear])" - "399([Symbol name=convert_element_type])" -> "400([Symbol name=linear])" - "401([Symbol name=linear])" - "21([Symbol name=unpack_trivial])" -> "401([Symbol name=linear])" - "399([Symbol name=convert_element_type])" -> "401([Symbol name=linear])" - "416([Symbol name=linear])" - "22([Symbol name=unpack_trivial])" -> "416([Symbol name=linear])" - "415([Symbol name=convert_element_type])" -> "416([Symbol name=linear])" - "331([Symbol name=broadcast_in_dim])" - "23([Symbol name=unpack_trivial])" -> "331([Symbol name=broadcast_in_dim])" - "395([Symbol name=broadcast_in_dim])" - "24([Symbol name=unpack_trivial])" -> "395([Symbol name=broadcast_in_dim])" - "436([Symbol name=linear])" - "25([Symbol name=unpack_trivial])" -> "436([Symbol name=linear])" - "435([Symbol name=convert_element_type])" -> "436([Symbol name=linear])" - "480([Symbol name=linear])" - "26([Symbol name=unpack_trivial])" -> "480([Symbol name=linear])" - "479([Symbol name=reshape])" -> "480([Symbol name=linear])" - "500([Symbol name=linear])" - "27([Symbol name=unpack_trivial])" -> "500([Symbol name=linear])" - "499([Symbol name=convert_element_type])" -> "500([Symbol name=linear])" - "501([Symbol name=linear])" - "499([Symbol name=convert_element_type])" -> "501([Symbol name=linear])" - "28([Symbol name=unpack_trivial])" -> "501([Symbol name=linear])" - "516([Symbol name=linear])" - "515([Symbol name=convert_element_type])" -> "516([Symbol name=linear])" - "29([Symbol name=unpack_trivial])" -> "516([Symbol name=linear])" - "431([Symbol name=broadcast_in_dim])" - "30([Symbol name=unpack_trivial])" -> "431([Symbol name=broadcast_in_dim])" - "495([Symbol name=broadcast_in_dim])" - "31([Symbol name=unpack_trivial])" -> "495([Symbol name=broadcast_in_dim])" - "536([Symbol name=linear])" - "32([Symbol name=unpack_trivial])" -> "536([Symbol name=linear])" - "535([Symbol name=convert_element_type])" -> "536([Symbol name=linear])" - "580([Symbol name=linear])" - "33([Symbol name=unpack_trivial])" -> "580([Symbol name=linear])" - "579([Symbol name=reshape])" -> "580([Symbol name=linear])" - "600([Symbol name=linear])" - "34([Symbol name=unpack_trivial])" -> "600([Symbol name=linear])" - "599([Symbol name=convert_element_type])" -> "600([Symbol name=linear])" - "601([Symbol name=linear])" - "35([Symbol name=unpack_trivial])" -> "601([Symbol name=linear])" - "599([Symbol name=convert_element_type])" -> "601([Symbol name=linear])" - "616([Symbol name=linear])" - "36([Symbol name=unpack_trivial])" -> "616([Symbol name=linear])" - "615([Symbol name=convert_element_type])" -> "616([Symbol name=linear])" - "531([Symbol name=broadcast_in_dim])" - "37([Symbol name=unpack_trivial])" -> "531([Symbol name=broadcast_in_dim])" - "595([Symbol name=broadcast_in_dim])" - "38([Symbol name=unpack_trivial])" -> "595([Symbol name=broadcast_in_dim])" - "636([Symbol name=linear])" - "635([Symbol name=convert_element_type])" -> "636([Symbol name=linear])" - "39([Symbol name=unpack_trivial])" -> "636([Symbol name=linear])" - "680([Symbol name=linear])" - "40([Symbol name=unpack_trivial])" -> "680([Symbol name=linear])" - "679([Symbol name=reshape])" -> "680([Symbol name=linear])" - "700([Symbol name=linear])" - "41([Symbol name=unpack_trivial])" -> "700([Symbol name=linear])" - "699([Symbol name=convert_element_type])" -> "700([Symbol name=linear])" - "701([Symbol name=linear])" - "42([Symbol name=unpack_trivial])" -> "701([Symbol name=linear])" - "699([Symbol name=convert_element_type])" -> "701([Symbol name=linear])" - "716([Symbol name=linear])" - "43([Symbol name=unpack_trivial])" -> "716([Symbol name=linear])" - "715([Symbol name=convert_element_type])" -> "716([Symbol name=linear])" - "631([Symbol name=broadcast_in_dim])" - "44([Symbol name=unpack_trivial])" -> "631([Symbol name=broadcast_in_dim])" - "695([Symbol name=broadcast_in_dim])" - "45([Symbol name=unpack_trivial])" -> "695([Symbol name=broadcast_in_dim])" - "736([Symbol name=linear])" - "46([Symbol name=unpack_trivial])" -> "736([Symbol name=linear])" - "735([Symbol name=convert_element_type])" -> "736([Symbol name=linear])" - "780([Symbol name=linear])" - "779([Symbol name=reshape])" -> "780([Symbol name=linear])" - "47([Symbol name=unpack_trivial])" -> "780([Symbol name=linear])" - "800([Symbol name=linear])" - "48([Symbol name=unpack_trivial])" -> "800([Symbol name=linear])" - "799([Symbol name=convert_element_type])" -> "800([Symbol name=linear])" - "801([Symbol name=linear])" - "49([Symbol name=unpack_trivial])" -> "801([Symbol name=linear])" - "799([Symbol name=convert_element_type])" -> "801([Symbol name=linear])" - "816([Symbol name=linear])" - "50([Symbol name=unpack_trivial])" -> "816([Symbol name=linear])" - "815([Symbol name=convert_element_type])" -> "816([Symbol name=linear])" - "731([Symbol name=broadcast_in_dim])" - "51([Symbol name=unpack_trivial])" -> "731([Symbol name=broadcast_in_dim])" - "795([Symbol name=broadcast_in_dim])" - "52([Symbol name=unpack_trivial])" -> "795([Symbol name=broadcast_in_dim])" - "836([Symbol name=linear])" - "835([Symbol name=convert_element_type])" -> "836([Symbol name=linear])" - "53([Symbol name=unpack_trivial])" -> "836([Symbol name=linear])" - "880([Symbol name=linear])" - "54([Symbol name=unpack_trivial])" -> "880([Symbol name=linear])" - "879([Symbol name=reshape])" -> "880([Symbol name=linear])" - "900([Symbol name=linear])" - "899([Symbol name=convert_element_type])" -> "900([Symbol name=linear])" - "55([Symbol name=unpack_trivial])" -> "900([Symbol name=linear])" - "901([Symbol name=linear])" - "56([Symbol name=unpack_trivial])" -> "901([Symbol name=linear])" - "899([Symbol name=convert_element_type])" -> "901([Symbol name=linear])" - "916([Symbol name=linear])" - "57([Symbol name=unpack_trivial])" -> "916([Symbol name=linear])" - "915([Symbol name=convert_element_type])" -> "916([Symbol name=linear])" - "831([Symbol name=broadcast_in_dim])" - "58([Symbol name=unpack_trivial])" -> "831([Symbol name=broadcast_in_dim])" - "895([Symbol name=broadcast_in_dim])" - "59([Symbol name=unpack_trivial])" -> "895([Symbol name=broadcast_in_dim])" - "936([Symbol name=linear])" - "60([Symbol name=unpack_trivial])" -> "936([Symbol name=linear])" - "935([Symbol name=convert_element_type])" -> "936([Symbol name=linear])" - "980([Symbol name=linear])" - "979([Symbol name=reshape])" -> "980([Symbol name=linear])" - "61([Symbol name=unpack_trivial])" -> "980([Symbol name=linear])" - "1000([Symbol name=linear])" - "62([Symbol name=unpack_trivial])" -> "1000([Symbol name=linear])" - "999([Symbol name=convert_element_type])" -> "1000([Symbol name=linear])" - "1001([Symbol name=linear])" - "63([Symbol name=unpack_trivial])" -> "1001([Symbol name=linear])" - "999([Symbol name=convert_element_type])" -> "1001([Symbol name=linear])" - "1016([Symbol name=linear])" - "64([Symbol name=unpack_trivial])" -> "1016([Symbol name=linear])" - "1015([Symbol name=convert_element_type])" -> "1016([Symbol name=linear])" - "931([Symbol name=broadcast_in_dim])" - "65([Symbol name=unpack_trivial])" -> "931([Symbol name=broadcast_in_dim])" - "995([Symbol name=broadcast_in_dim])" - "66([Symbol name=unpack_trivial])" -> "995([Symbol name=broadcast_in_dim])" - "1036([Symbol name=linear])" - "67([Symbol name=unpack_trivial])" -> "1036([Symbol name=linear])" - "1035([Symbol name=convert_element_type])" -> "1036([Symbol name=linear])" - "1080([Symbol name=linear])" - "68([Symbol name=unpack_trivial])" -> "1080([Symbol name=linear])" - "1079([Symbol name=reshape])" -> "1080([Symbol name=linear])" - "1100([Symbol name=linear])" - "1099([Symbol name=convert_element_type])" -> "1100([Symbol name=linear])" - "69([Symbol name=unpack_trivial])" -> "1100([Symbol name=linear])" - "1101([Symbol name=linear])" - "1099([Symbol name=convert_element_type])" -> "1101([Symbol name=linear])" - "70([Symbol name=unpack_trivial])" -> "1101([Symbol name=linear])" - "1116([Symbol name=linear])" - "1115([Symbol name=convert_element_type])" -> "1116([Symbol name=linear])" - "71([Symbol name=unpack_trivial])" -> "1116([Symbol name=linear])" - "1031([Symbol name=broadcast_in_dim])" - "72([Symbol name=unpack_trivial])" -> "1031([Symbol name=broadcast_in_dim])" - "1095([Symbol name=broadcast_in_dim])" - "73([Symbol name=unpack_trivial])" -> "1095([Symbol name=broadcast_in_dim])" - "1136([Symbol name=linear])" - "74([Symbol name=unpack_trivial])" -> "1136([Symbol name=linear])" - "1135([Symbol name=convert_element_type])" -> "1136([Symbol name=linear])" - "1180([Symbol name=linear])" - "75([Symbol name=unpack_trivial])" -> "1180([Symbol name=linear])" - "1179([Symbol name=reshape])" -> "1180([Symbol name=linear])" - "1200([Symbol name=linear])" - "76([Symbol name=unpack_trivial])" -> "1200([Symbol name=linear])" - "1199([Symbol name=convert_element_type])" -> "1200([Symbol name=linear])" - "1201([Symbol name=linear])" - "77([Symbol name=unpack_trivial])" -> "1201([Symbol name=linear])" - "1199([Symbol name=convert_element_type])" -> "1201([Symbol name=linear])" - "1216([Symbol name=linear])" - "78([Symbol name=unpack_trivial])" -> "1216([Symbol name=linear])" - "1215([Symbol name=convert_element_type])" -> "1216([Symbol name=linear])" - "1131([Symbol name=broadcast_in_dim])" - "79([Symbol name=unpack_trivial])" -> "1131([Symbol name=broadcast_in_dim])" - "1195([Symbol name=broadcast_in_dim])" - "80([Symbol name=unpack_trivial])" -> "1195([Symbol name=broadcast_in_dim])" - "1236([Symbol name=linear])" - "81([Symbol name=unpack_trivial])" -> "1236([Symbol name=linear])" - "1235([Symbol name=convert_element_type])" -> "1236([Symbol name=linear])" - "1280([Symbol name=linear])" - "82([Symbol name=unpack_trivial])" -> "1280([Symbol name=linear])" - "1279([Symbol name=reshape])" -> "1280([Symbol name=linear])" - "1300([Symbol name=linear])" - "83([Symbol name=unpack_trivial])" -> "1300([Symbol name=linear])" - "1299([Symbol name=convert_element_type])" -> "1300([Symbol name=linear])" - "1301([Symbol name=linear])" - "1299([Symbol name=convert_element_type])" -> "1301([Symbol name=linear])" - "84([Symbol name=unpack_trivial])" -> "1301([Symbol name=linear])" - "1316([Symbol name=linear])" - "1315([Symbol name=convert_element_type])" -> "1316([Symbol name=linear])" - "85([Symbol name=unpack_trivial])" -> "1316([Symbol name=linear])" - "1231([Symbol name=broadcast_in_dim])" - "86([Symbol name=unpack_trivial])" -> "1231([Symbol name=broadcast_in_dim])" - "1295([Symbol name=broadcast_in_dim])" - "87([Symbol name=unpack_trivial])" -> "1295([Symbol name=broadcast_in_dim])" - "1336([Symbol name=linear])" - "88([Symbol name=unpack_trivial])" -> "1336([Symbol name=linear])" - "1335([Symbol name=convert_element_type])" -> "1336([Symbol name=linear])" - "1380([Symbol name=linear])" - "89([Symbol name=unpack_trivial])" -> "1380([Symbol name=linear])" - "1379([Symbol name=reshape])" -> "1380([Symbol name=linear])" - "1400([Symbol name=linear])" - "90([Symbol name=unpack_trivial])" -> "1400([Symbol name=linear])" - "1399([Symbol name=convert_element_type])" -> "1400([Symbol name=linear])" - "1401([Symbol name=linear])" - "91([Symbol name=unpack_trivial])" -> "1401([Symbol name=linear])" - "1399([Symbol name=convert_element_type])" -> "1401([Symbol name=linear])" - "1416([Symbol name=linear])" - "92([Symbol name=unpack_trivial])" -> "1416([Symbol name=linear])" - "1415([Symbol name=convert_element_type])" -> "1416([Symbol name=linear])" - "1331([Symbol name=broadcast_in_dim])" - "93([Symbol name=unpack_trivial])" -> "1331([Symbol name=broadcast_in_dim])" - "1395([Symbol name=broadcast_in_dim])" - "94([Symbol name=unpack_trivial])" -> "1395([Symbol name=broadcast_in_dim])" - "1436([Symbol name=linear])" - "1435([Symbol name=convert_element_type])" -> "1436([Symbol name=linear])" - "95([Symbol name=unpack_trivial])" -> "1436([Symbol name=linear])" - "1480([Symbol name=linear])" - "96([Symbol name=unpack_trivial])" -> "1480([Symbol name=linear])" - "1479([Symbol name=reshape])" -> "1480([Symbol name=linear])" - "1500([Symbol name=linear])" - "97([Symbol name=unpack_trivial])" -> "1500([Symbol name=linear])" - "1499([Symbol name=convert_element_type])" -> "1500([Symbol name=linear])" - "1501([Symbol name=linear])" - "98([Symbol name=unpack_trivial])" -> "1501([Symbol name=linear])" - "1499([Symbol name=convert_element_type])" -> "1501([Symbol name=linear])" - "1516([Symbol name=linear])" - "99([Symbol name=unpack_trivial])" -> "1516([Symbol name=linear])" - "1515([Symbol name=convert_element_type])" -> "1516([Symbol name=linear])" - "1431([Symbol name=broadcast_in_dim])" - "100([Symbol name=unpack_trivial])" -> "1431([Symbol name=broadcast_in_dim])" - "1495([Symbol name=broadcast_in_dim])" - "101([Symbol name=unpack_trivial])" -> "1495([Symbol name=broadcast_in_dim])" - "1536([Symbol name=linear])" - "102([Symbol name=unpack_trivial])" -> "1536([Symbol name=linear])" - "1535([Symbol name=convert_element_type])" -> "1536([Symbol name=linear])" - "1580([Symbol name=linear])" - "1579([Symbol name=reshape])" -> "1580([Symbol name=linear])" - "103([Symbol name=unpack_trivial])" -> "1580([Symbol name=linear])" - "1600([Symbol name=linear])" - "104([Symbol name=unpack_trivial])" -> "1600([Symbol name=linear])" - "1599([Symbol name=convert_element_type])" -> "1600([Symbol name=linear])" - "1601([Symbol name=linear])" - "105([Symbol name=unpack_trivial])" -> "1601([Symbol name=linear])" - "1599([Symbol name=convert_element_type])" -> "1601([Symbol name=linear])" - "1616([Symbol name=linear])" - "106([Symbol name=unpack_trivial])" -> "1616([Symbol name=linear])" - "1615([Symbol name=convert_element_type])" -> "1616([Symbol name=linear])" - "1531([Symbol name=broadcast_in_dim])" - "107([Symbol name=unpack_trivial])" -> "1531([Symbol name=broadcast_in_dim])" - "1595([Symbol name=broadcast_in_dim])" - "108([Symbol name=unpack_trivial])" -> "1595([Symbol name=broadcast_in_dim])" - "1636([Symbol name=linear])" - "1635([Symbol name=convert_element_type])" -> "1636([Symbol name=linear])" - "109([Symbol name=unpack_trivial])" -> "1636([Symbol name=linear])" - "1680([Symbol name=linear])" - "110([Symbol name=unpack_trivial])" -> "1680([Symbol name=linear])" - "1679([Symbol name=reshape])" -> "1680([Symbol name=linear])" - "1700([Symbol name=linear])" - "1699([Symbol name=convert_element_type])" -> "1700([Symbol name=linear])" - "111([Symbol name=unpack_trivial])" -> "1700([Symbol name=linear])" - "1701([Symbol name=linear])" - "112([Symbol name=unpack_trivial])" -> "1701([Symbol name=linear])" - "1699([Symbol name=convert_element_type])" -> "1701([Symbol name=linear])" - "1716([Symbol name=linear])" - "113([Symbol name=unpack_trivial])" -> "1716([Symbol name=linear])" - "1715([Symbol name=convert_element_type])" -> "1716([Symbol name=linear])" - "1631([Symbol name=broadcast_in_dim])" - "114([Symbol name=unpack_trivial])" -> "1631([Symbol name=broadcast_in_dim])" - "1695([Symbol name=broadcast_in_dim])" - "115([Symbol name=unpack_trivial])" -> "1695([Symbol name=broadcast_in_dim])" - "1731([Symbol name=broadcast_in_dim])" - "116([Symbol name=unpack_trivial])" -> "1731([Symbol name=broadcast_in_dim])" - "121([Symbol name=convert_element_type])" - "120([Symbol name=embedding])" -> "121([Symbol name=convert_element_type])" - "182([Symbol name=convert_element_type])" - "120([Symbol name=embedding])" -> "182([Symbol name=convert_element_type])" - "1665([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1665([Symbol name=broadcast_in_dim])" - "265([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "265([Symbol name=broadcast_in_dim])" - "650([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "650([Symbol name=broadcast_in_dim])" - "1165([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1165([Symbol name=broadcast_in_dim])" - "1550([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1550([Symbol name=broadcast_in_dim])" - "150([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "150([Symbol name=broadcast_in_dim])" - "665([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "665([Symbol name=broadcast_in_dim])" - "1050([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1050([Symbol name=broadcast_in_dim])" - "1565([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1565([Symbol name=broadcast_in_dim])" - "165([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "165([Symbol name=broadcast_in_dim])" - "550([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "550([Symbol name=broadcast_in_dim])" - "1065([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1065([Symbol name=broadcast_in_dim])" - "1450([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1450([Symbol name=broadcast_in_dim])" - "565([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "565([Symbol name=broadcast_in_dim])" - "950([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "950([Symbol name=broadcast_in_dim])" - "1465([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1465([Symbol name=broadcast_in_dim])" - "450([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "450([Symbol name=broadcast_in_dim])" - "965([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "965([Symbol name=broadcast_in_dim])" - "1350([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1350([Symbol name=broadcast_in_dim])" - "465([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "465([Symbol name=broadcast_in_dim])" - "850([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "850([Symbol name=broadcast_in_dim])" - "1365([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1365([Symbol name=broadcast_in_dim])" - "350([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "350([Symbol name=broadcast_in_dim])" - "865([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "865([Symbol name=broadcast_in_dim])" - "1250([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1250([Symbol name=broadcast_in_dim])" - "365([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "365([Symbol name=broadcast_in_dim])" - "750([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "750([Symbol name=broadcast_in_dim])" - "1265([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1265([Symbol name=broadcast_in_dim])" - "1650([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1650([Symbol name=broadcast_in_dim])" - "250([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "250([Symbol name=broadcast_in_dim])" - "765([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "765([Symbol name=broadcast_in_dim])" - "1150([Symbol name=broadcast_in_dim])" - "118([Symbol name=slice_prim])" -> "1150([Symbol name=broadcast_in_dim])" - "768([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "768([Symbol name=broadcast_in_dim])" - "1153([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1153([Symbol name=broadcast_in_dim])" - "1668([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1668([Symbol name=broadcast_in_dim])" - "268([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "268([Symbol name=broadcast_in_dim])" - "653([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "653([Symbol name=broadcast_in_dim])" - "1168([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1168([Symbol name=broadcast_in_dim])" - "1553([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1553([Symbol name=broadcast_in_dim])" - "153([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "153([Symbol name=broadcast_in_dim])" - "668([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "668([Symbol name=broadcast_in_dim])" - "1053([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1053([Symbol name=broadcast_in_dim])" - "1568([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1568([Symbol name=broadcast_in_dim])" - "168([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "168([Symbol name=broadcast_in_dim])" - "553([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "553([Symbol name=broadcast_in_dim])" - "1068([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1068([Symbol name=broadcast_in_dim])" - "1453([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1453([Symbol name=broadcast_in_dim])" - "568([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "568([Symbol name=broadcast_in_dim])" - "953([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "953([Symbol name=broadcast_in_dim])" - "1468([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1468([Symbol name=broadcast_in_dim])" - "453([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "453([Symbol name=broadcast_in_dim])" - "968([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "968([Symbol name=broadcast_in_dim])" - "1353([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1353([Symbol name=broadcast_in_dim])" - "468([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "468([Symbol name=broadcast_in_dim])" - "853([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "853([Symbol name=broadcast_in_dim])" - "1368([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1368([Symbol name=broadcast_in_dim])" - "353([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "353([Symbol name=broadcast_in_dim])" - "868([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "868([Symbol name=broadcast_in_dim])" - "1253([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1253([Symbol name=broadcast_in_dim])" - "368([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "368([Symbol name=broadcast_in_dim])" - "753([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "753([Symbol name=broadcast_in_dim])" - "1268([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1268([Symbol name=broadcast_in_dim])" - "1653([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "1653([Symbol name=broadcast_in_dim])" - "253([Symbol name=broadcast_in_dim])" - "119([Symbol name=slice_prim])" -> "253([Symbol name=broadcast_in_dim])" - "137([Symbol name=reshape])" - "136([Symbol name=linear])" -> "137([Symbol name=reshape])" - "181([Symbol name=convert_element_type])" - "180([Symbol name=linear])" -> "181([Symbol name=convert_element_type])" - "208([Symbol name=convert_element_type])" - "200([Symbol name=linear])" -> "208([Symbol name=convert_element_type])" - "202([Symbol name=convert_element_type])" - "200([Symbol name=linear])" -> "202([Symbol name=convert_element_type])" - "213([Symbol name=convert_element_type])" - "201([Symbol name=linear])" -> "213([Symbol name=convert_element_type])" - "217([Symbol name=convert_element_type])" - "216([Symbol name=linear])" -> "217([Symbol name=convert_element_type])" - "133([Symbol name=convert_element_type])" - "131([Symbol name=broadcast_in_dim])" -> "133([Symbol name=convert_element_type])" - "197([Symbol name=convert_element_type])" - "195([Symbol name=broadcast_in_dim])" -> "197([Symbol name=convert_element_type])" - "237([Symbol name=reshape])" - "236([Symbol name=linear])" -> "237([Symbol name=reshape])" - "281([Symbol name=convert_element_type])" - "280([Symbol name=linear])" -> "281([Symbol name=convert_element_type])" - "308([Symbol name=convert_element_type])" - "300([Symbol name=linear])" -> "308([Symbol name=convert_element_type])" - "302([Symbol name=convert_element_type])" - "300([Symbol name=linear])" -> "302([Symbol name=convert_element_type])" - "313([Symbol name=convert_element_type])" - "301([Symbol name=linear])" -> "313([Symbol name=convert_element_type])" - "317([Symbol name=convert_element_type])" - "316([Symbol name=linear])" -> "317([Symbol name=convert_element_type])" - "233([Symbol name=convert_element_type])" - "231([Symbol name=broadcast_in_dim])" -> "233([Symbol name=convert_element_type])" - "297([Symbol name=convert_element_type])" - "295([Symbol name=broadcast_in_dim])" -> "297([Symbol name=convert_element_type])" - "337([Symbol name=reshape])" - "336([Symbol name=linear])" -> "337([Symbol name=reshape])" - "381([Symbol name=convert_element_type])" - "380([Symbol name=linear])" -> "381([Symbol name=convert_element_type])" - "408([Symbol name=convert_element_type])" - "400([Symbol name=linear])" -> "408([Symbol name=convert_element_type])" - "402([Symbol name=convert_element_type])" - "400([Symbol name=linear])" -> "402([Symbol name=convert_element_type])" - "413([Symbol name=convert_element_type])" - "401([Symbol name=linear])" -> "413([Symbol name=convert_element_type])" - "417([Symbol name=convert_element_type])" - "416([Symbol name=linear])" -> "417([Symbol name=convert_element_type])" - "333([Symbol name=convert_element_type])" - "331([Symbol name=broadcast_in_dim])" -> "333([Symbol name=convert_element_type])" - "397([Symbol name=convert_element_type])" - "395([Symbol name=broadcast_in_dim])" -> "397([Symbol name=convert_element_type])" - "437([Symbol name=reshape])" - "436([Symbol name=linear])" -> "437([Symbol name=reshape])" - "481([Symbol name=convert_element_type])" - "480([Symbol name=linear])" -> "481([Symbol name=convert_element_type])" - "508([Symbol name=convert_element_type])" - "500([Symbol name=linear])" -> "508([Symbol name=convert_element_type])" - "502([Symbol name=convert_element_type])" - "500([Symbol name=linear])" -> "502([Symbol name=convert_element_type])" - "513([Symbol name=convert_element_type])" - "501([Symbol name=linear])" -> "513([Symbol name=convert_element_type])" - "517([Symbol name=convert_element_type])" - "516([Symbol name=linear])" -> "517([Symbol name=convert_element_type])" - "433([Symbol name=convert_element_type])" - "431([Symbol name=broadcast_in_dim])" -> "433([Symbol name=convert_element_type])" - "497([Symbol name=convert_element_type])" - "495([Symbol name=broadcast_in_dim])" -> "497([Symbol name=convert_element_type])" - "537([Symbol name=reshape])" - "536([Symbol name=linear])" -> "537([Symbol name=reshape])" - "581([Symbol name=convert_element_type])" - "580([Symbol name=linear])" -> "581([Symbol name=convert_element_type])" - "608([Symbol name=convert_element_type])" - "600([Symbol name=linear])" -> "608([Symbol name=convert_element_type])" - "602([Symbol name=convert_element_type])" - "600([Symbol name=linear])" -> "602([Symbol name=convert_element_type])" - "613([Symbol name=convert_element_type])" - "601([Symbol name=linear])" -> "613([Symbol name=convert_element_type])" - "617([Symbol name=convert_element_type])" - "616([Symbol name=linear])" -> "617([Symbol name=convert_element_type])" - "533([Symbol name=convert_element_type])" - "531([Symbol name=broadcast_in_dim])" -> "533([Symbol name=convert_element_type])" - "597([Symbol name=convert_element_type])" - "595([Symbol name=broadcast_in_dim])" -> "597([Symbol name=convert_element_type])" - "637([Symbol name=reshape])" - "636([Symbol name=linear])" -> "637([Symbol name=reshape])" - "681([Symbol name=convert_element_type])" - "680([Symbol name=linear])" -> "681([Symbol name=convert_element_type])" - "708([Symbol name=convert_element_type])" - "700([Symbol name=linear])" -> "708([Symbol name=convert_element_type])" - "702([Symbol name=convert_element_type])" - "700([Symbol name=linear])" -> "702([Symbol name=convert_element_type])" - "713([Symbol name=convert_element_type])" - "701([Symbol name=linear])" -> "713([Symbol name=convert_element_type])" - "717([Symbol name=convert_element_type])" - "716([Symbol name=linear])" -> "717([Symbol name=convert_element_type])" - "633([Symbol name=convert_element_type])" - "631([Symbol name=broadcast_in_dim])" -> "633([Symbol name=convert_element_type])" - "697([Symbol name=convert_element_type])" - "695([Symbol name=broadcast_in_dim])" -> "697([Symbol name=convert_element_type])" - "737([Symbol name=reshape])" - "736([Symbol name=linear])" -> "737([Symbol name=reshape])" - "781([Symbol name=convert_element_type])" - "780([Symbol name=linear])" -> "781([Symbol name=convert_element_type])" - "808([Symbol name=convert_element_type])" - "800([Symbol name=linear])" -> "808([Symbol name=convert_element_type])" - "802([Symbol name=convert_element_type])" - "800([Symbol name=linear])" -> "802([Symbol name=convert_element_type])" - "813([Symbol name=convert_element_type])" - "801([Symbol name=linear])" -> "813([Symbol name=convert_element_type])" - "817([Symbol name=convert_element_type])" - "816([Symbol name=linear])" -> "817([Symbol name=convert_element_type])" - "733([Symbol name=convert_element_type])" - "731([Symbol name=broadcast_in_dim])" -> "733([Symbol name=convert_element_type])" - "797([Symbol name=convert_element_type])" - "795([Symbol name=broadcast_in_dim])" -> "797([Symbol name=convert_element_type])" - "837([Symbol name=reshape])" - "836([Symbol name=linear])" -> "837([Symbol name=reshape])" - "881([Symbol name=convert_element_type])" - "880([Symbol name=linear])" -> "881([Symbol name=convert_element_type])" - "908([Symbol name=convert_element_type])" - "900([Symbol name=linear])" -> "908([Symbol name=convert_element_type])" - "902([Symbol name=convert_element_type])" - "900([Symbol name=linear])" -> "902([Symbol name=convert_element_type])" - "913([Symbol name=convert_element_type])" - "901([Symbol name=linear])" -> "913([Symbol name=convert_element_type])" - "917([Symbol name=convert_element_type])" - "916([Symbol name=linear])" -> "917([Symbol name=convert_element_type])" - "833([Symbol name=convert_element_type])" - "831([Symbol name=broadcast_in_dim])" -> "833([Symbol name=convert_element_type])" - "897([Symbol name=convert_element_type])" - "895([Symbol name=broadcast_in_dim])" -> "897([Symbol name=convert_element_type])" - "937([Symbol name=reshape])" - "936([Symbol name=linear])" -> "937([Symbol name=reshape])" - "981([Symbol name=convert_element_type])" - "980([Symbol name=linear])" -> "981([Symbol name=convert_element_type])" - "1008([Symbol name=convert_element_type])" - "1000([Symbol name=linear])" -> "1008([Symbol name=convert_element_type])" - "1002([Symbol name=convert_element_type])" - "1000([Symbol name=linear])" -> "1002([Symbol name=convert_element_type])" - "1013([Symbol name=convert_element_type])" - "1001([Symbol name=linear])" -> "1013([Symbol name=convert_element_type])" - "1017([Symbol name=convert_element_type])" - "1016([Symbol name=linear])" -> "1017([Symbol name=convert_element_type])" - "933([Symbol name=convert_element_type])" - "931([Symbol name=broadcast_in_dim])" -> "933([Symbol name=convert_element_type])" - "997([Symbol name=convert_element_type])" - "995([Symbol name=broadcast_in_dim])" -> "997([Symbol name=convert_element_type])" - "1037([Symbol name=reshape])" - "1036([Symbol name=linear])" -> "1037([Symbol name=reshape])" - "1081([Symbol name=convert_element_type])" - "1080([Symbol name=linear])" -> "1081([Symbol name=convert_element_type])" - "1108([Symbol name=convert_element_type])" - "1100([Symbol name=linear])" -> "1108([Symbol name=convert_element_type])" - "1102([Symbol name=convert_element_type])" - "1100([Symbol name=linear])" -> "1102([Symbol name=convert_element_type])" - "1113([Symbol name=convert_element_type])" - "1101([Symbol name=linear])" -> "1113([Symbol name=convert_element_type])" - "1117([Symbol name=convert_element_type])" - "1116([Symbol name=linear])" -> "1117([Symbol name=convert_element_type])" - "1033([Symbol name=convert_element_type])" - "1031([Symbol name=broadcast_in_dim])" -> "1033([Symbol name=convert_element_type])" - "1097([Symbol name=convert_element_type])" - "1095([Symbol name=broadcast_in_dim])" -> "1097([Symbol name=convert_element_type])" - "1137([Symbol name=reshape])" - "1136([Symbol name=linear])" -> "1137([Symbol name=reshape])" - "1181([Symbol name=convert_element_type])" - "1180([Symbol name=linear])" -> "1181([Symbol name=convert_element_type])" - "1208([Symbol name=convert_element_type])" - "1200([Symbol name=linear])" -> "1208([Symbol name=convert_element_type])" - "1202([Symbol name=convert_element_type])" - "1200([Symbol name=linear])" -> "1202([Symbol name=convert_element_type])" - "1213([Symbol name=convert_element_type])" - "1201([Symbol name=linear])" -> "1213([Symbol name=convert_element_type])" - "1217([Symbol name=convert_element_type])" - "1216([Symbol name=linear])" -> "1217([Symbol name=convert_element_type])" - "1133([Symbol name=convert_element_type])" - "1131([Symbol name=broadcast_in_dim])" -> "1133([Symbol name=convert_element_type])" - "1197([Symbol name=convert_element_type])" - "1195([Symbol name=broadcast_in_dim])" -> "1197([Symbol name=convert_element_type])" - "1237([Symbol name=reshape])" - "1236([Symbol name=linear])" -> "1237([Symbol name=reshape])" - "1281([Symbol name=convert_element_type])" - "1280([Symbol name=linear])" -> "1281([Symbol name=convert_element_type])" - "1308([Symbol name=convert_element_type])" - "1300([Symbol name=linear])" -> "1308([Symbol name=convert_element_type])" - "1302([Symbol name=convert_element_type])" - "1300([Symbol name=linear])" -> "1302([Symbol name=convert_element_type])" - "1313([Symbol name=convert_element_type])" - "1301([Symbol name=linear])" -> "1313([Symbol name=convert_element_type])" - "1317([Symbol name=convert_element_type])" - "1316([Symbol name=linear])" -> "1317([Symbol name=convert_element_type])" - "1233([Symbol name=convert_element_type])" - "1231([Symbol name=broadcast_in_dim])" -> "1233([Symbol name=convert_element_type])" - "1297([Symbol name=convert_element_type])" - "1295([Symbol name=broadcast_in_dim])" -> "1297([Symbol name=convert_element_type])" - "1337([Symbol name=reshape])" - "1336([Symbol name=linear])" -> "1337([Symbol name=reshape])" - "1381([Symbol name=convert_element_type])" - "1380([Symbol name=linear])" -> "1381([Symbol name=convert_element_type])" - "1408([Symbol name=convert_element_type])" - "1400([Symbol name=linear])" -> "1408([Symbol name=convert_element_type])" - "1402([Symbol name=convert_element_type])" - "1400([Symbol name=linear])" -> "1402([Symbol name=convert_element_type])" - "1413([Symbol name=convert_element_type])" - "1401([Symbol name=linear])" -> "1413([Symbol name=convert_element_type])" - "1417([Symbol name=convert_element_type])" - "1416([Symbol name=linear])" -> "1417([Symbol name=convert_element_type])" - "1333([Symbol name=convert_element_type])" - "1331([Symbol name=broadcast_in_dim])" -> "1333([Symbol name=convert_element_type])" - "1397([Symbol name=convert_element_type])" - "1395([Symbol name=broadcast_in_dim])" -> "1397([Symbol name=convert_element_type])" - "1437([Symbol name=reshape])" - "1436([Symbol name=linear])" -> "1437([Symbol name=reshape])" - "1481([Symbol name=convert_element_type])" - "1480([Symbol name=linear])" -> "1481([Symbol name=convert_element_type])" - "1508([Symbol name=convert_element_type])" - "1500([Symbol name=linear])" -> "1508([Symbol name=convert_element_type])" - "1502([Symbol name=convert_element_type])" - "1500([Symbol name=linear])" -> "1502([Symbol name=convert_element_type])" - "1513([Symbol name=convert_element_type])" - "1501([Symbol name=linear])" -> "1513([Symbol name=convert_element_type])" - "1517([Symbol name=convert_element_type])" - "1516([Symbol name=linear])" -> "1517([Symbol name=convert_element_type])" - "1433([Symbol name=convert_element_type])" - "1431([Symbol name=broadcast_in_dim])" -> "1433([Symbol name=convert_element_type])" - "1497([Symbol name=convert_element_type])" - "1495([Symbol name=broadcast_in_dim])" -> "1497([Symbol name=convert_element_type])" - "1537([Symbol name=reshape])" - "1536([Symbol name=linear])" -> "1537([Symbol name=reshape])" - "1581([Symbol name=convert_element_type])" - "1580([Symbol name=linear])" -> "1581([Symbol name=convert_element_type])" - "1608([Symbol name=convert_element_type])" - "1600([Symbol name=linear])" -> "1608([Symbol name=convert_element_type])" - "1602([Symbol name=convert_element_type])" - "1600([Symbol name=linear])" -> "1602([Symbol name=convert_element_type])" - "1613([Symbol name=convert_element_type])" - "1601([Symbol name=linear])" -> "1613([Symbol name=convert_element_type])" - "1617([Symbol name=convert_element_type])" - "1616([Symbol name=linear])" -> "1617([Symbol name=convert_element_type])" - "1533([Symbol name=convert_element_type])" - "1531([Symbol name=broadcast_in_dim])" -> "1533([Symbol name=convert_element_type])" - "1597([Symbol name=convert_element_type])" - "1595([Symbol name=broadcast_in_dim])" -> "1597([Symbol name=convert_element_type])" - "1637([Symbol name=reshape])" - "1636([Symbol name=linear])" -> "1637([Symbol name=reshape])" - "1681([Symbol name=convert_element_type])" - "1680([Symbol name=linear])" -> "1681([Symbol name=convert_element_type])" - "1708([Symbol name=convert_element_type])" - "1700([Symbol name=linear])" -> "1708([Symbol name=convert_element_type])" - "1702([Symbol name=convert_element_type])" - "1700([Symbol name=linear])" -> "1702([Symbol name=convert_element_type])" - "1713([Symbol name=convert_element_type])" - "1701([Symbol name=linear])" -> "1713([Symbol name=convert_element_type])" - "1717([Symbol name=convert_element_type])" - "1716([Symbol name=linear])" -> "1717([Symbol name=convert_element_type])" - "1633([Symbol name=convert_element_type])" - "1631([Symbol name=broadcast_in_dim])" -> "1633([Symbol name=convert_element_type])" - "1697([Symbol name=convert_element_type])" - "1695([Symbol name=broadcast_in_dim])" -> "1697([Symbol name=convert_element_type])" - "1733([Symbol name=convert_element_type])" - "1731([Symbol name=broadcast_in_dim])" -> "1733([Symbol name=convert_element_type])" - "129([Symbol name=mul])" - "128([Symbol name=broadcast_in_dim])" -> "129([Symbol name=mul])" - "121([Symbol name=convert_element_type])" -> "129([Symbol name=mul])" - "122([Symbol name=mul])" - "121([Symbol name=convert_element_type])" -> "122([Symbol name=mul])" - "183([Symbol name=add])" - "181([Symbol name=convert_element_type])" -> "183([Symbol name=add])" - "182([Symbol name=convert_element_type])" -> "183([Symbol name=add])" - "1667([Symbol name=mul])" - "1665([Symbol name=broadcast_in_dim])" -> "1667([Symbol name=mul])" - "1666([Symbol name=convert_element_type])" -> "1667([Symbol name=mul])" - "267([Symbol name=mul])" - "265([Symbol name=broadcast_in_dim])" -> "267([Symbol name=mul])" - "266([Symbol name=convert_element_type])" -> "267([Symbol name=mul])" - "652([Symbol name=mul])" - "650([Symbol name=broadcast_in_dim])" -> "652([Symbol name=mul])" - "651([Symbol name=convert_element_type])" -> "652([Symbol name=mul])" - "1167([Symbol name=mul])" - "1165([Symbol name=broadcast_in_dim])" -> "1167([Symbol name=mul])" - "1166([Symbol name=convert_element_type])" -> "1167([Symbol name=mul])" - "1552([Symbol name=mul])" - "1550([Symbol name=broadcast_in_dim])" -> "1552([Symbol name=mul])" - "1551([Symbol name=convert_element_type])" -> "1552([Symbol name=mul])" - "152([Symbol name=mul])" - "150([Symbol name=broadcast_in_dim])" -> "152([Symbol name=mul])" - "151([Symbol name=convert_element_type])" -> "152([Symbol name=mul])" - "667([Symbol name=mul])" - "665([Symbol name=broadcast_in_dim])" -> "667([Symbol name=mul])" - "666([Symbol name=convert_element_type])" -> "667([Symbol name=mul])" - "1052([Symbol name=mul])" - "1050([Symbol name=broadcast_in_dim])" -> "1052([Symbol name=mul])" - "1051([Symbol name=convert_element_type])" -> "1052([Symbol name=mul])" - "1567([Symbol name=mul])" - "1565([Symbol name=broadcast_in_dim])" -> "1567([Symbol name=mul])" - "1566([Symbol name=convert_element_type])" -> "1567([Symbol name=mul])" - "167([Symbol name=mul])" - "165([Symbol name=broadcast_in_dim])" -> "167([Symbol name=mul])" - "166([Symbol name=convert_element_type])" -> "167([Symbol name=mul])" - "552([Symbol name=mul])" - "550([Symbol name=broadcast_in_dim])" -> "552([Symbol name=mul])" - "551([Symbol name=convert_element_type])" -> "552([Symbol name=mul])" - "1067([Symbol name=mul])" - "1065([Symbol name=broadcast_in_dim])" -> "1067([Symbol name=mul])" - "1066([Symbol name=convert_element_type])" -> "1067([Symbol name=mul])" - "1452([Symbol name=mul])" - "1450([Symbol name=broadcast_in_dim])" -> "1452([Symbol name=mul])" - "1451([Symbol name=convert_element_type])" -> "1452([Symbol name=mul])" - "567([Symbol name=mul])" - "565([Symbol name=broadcast_in_dim])" -> "567([Symbol name=mul])" - "566([Symbol name=convert_element_type])" -> "567([Symbol name=mul])" - "952([Symbol name=mul])" - "950([Symbol name=broadcast_in_dim])" -> "952([Symbol name=mul])" - "951([Symbol name=convert_element_type])" -> "952([Symbol name=mul])" - "1467([Symbol name=mul])" - "1465([Symbol name=broadcast_in_dim])" -> "1467([Symbol name=mul])" - "1466([Symbol name=convert_element_type])" -> "1467([Symbol name=mul])" - "452([Symbol name=mul])" - "450([Symbol name=broadcast_in_dim])" -> "452([Symbol name=mul])" - "451([Symbol name=convert_element_type])" -> "452([Symbol name=mul])" - "967([Symbol name=mul])" - "965([Symbol name=broadcast_in_dim])" -> "967([Symbol name=mul])" - "966([Symbol name=convert_element_type])" -> "967([Symbol name=mul])" - "1352([Symbol name=mul])" - "1350([Symbol name=broadcast_in_dim])" -> "1352([Symbol name=mul])" - "1351([Symbol name=convert_element_type])" -> "1352([Symbol name=mul])" - "467([Symbol name=mul])" - "465([Symbol name=broadcast_in_dim])" -> "467([Symbol name=mul])" - "466([Symbol name=convert_element_type])" -> "467([Symbol name=mul])" - "852([Symbol name=mul])" - "850([Symbol name=broadcast_in_dim])" -> "852([Symbol name=mul])" - "851([Symbol name=convert_element_type])" -> "852([Symbol name=mul])" - "1367([Symbol name=mul])" - "1365([Symbol name=broadcast_in_dim])" -> "1367([Symbol name=mul])" - "1366([Symbol name=convert_element_type])" -> "1367([Symbol name=mul])" - "352([Symbol name=mul])" - "350([Symbol name=broadcast_in_dim])" -> "352([Symbol name=mul])" - "351([Symbol name=convert_element_type])" -> "352([Symbol name=mul])" - "867([Symbol name=mul])" - "865([Symbol name=broadcast_in_dim])" -> "867([Symbol name=mul])" - "866([Symbol name=convert_element_type])" -> "867([Symbol name=mul])" - "1252([Symbol name=mul])" - "1250([Symbol name=broadcast_in_dim])" -> "1252([Symbol name=mul])" - "1251([Symbol name=convert_element_type])" -> "1252([Symbol name=mul])" - "367([Symbol name=mul])" - "365([Symbol name=broadcast_in_dim])" -> "367([Symbol name=mul])" - "366([Symbol name=convert_element_type])" -> "367([Symbol name=mul])" - "752([Symbol name=mul])" - "750([Symbol name=broadcast_in_dim])" -> "752([Symbol name=mul])" - "751([Symbol name=convert_element_type])" -> "752([Symbol name=mul])" - "1267([Symbol name=mul])" - "1265([Symbol name=broadcast_in_dim])" -> "1267([Symbol name=mul])" - "1266([Symbol name=convert_element_type])" -> "1267([Symbol name=mul])" - "1652([Symbol name=mul])" - "1650([Symbol name=broadcast_in_dim])" -> "1652([Symbol name=mul])" - "1651([Symbol name=convert_element_type])" -> "1652([Symbol name=mul])" - "252([Symbol name=mul])" - "250([Symbol name=broadcast_in_dim])" -> "252([Symbol name=mul])" - "251([Symbol name=convert_element_type])" -> "252([Symbol name=mul])" - "767([Symbol name=mul])" - "765([Symbol name=broadcast_in_dim])" -> "767([Symbol name=mul])" - "766([Symbol name=convert_element_type])" -> "767([Symbol name=mul])" - "1152([Symbol name=mul])" - "1150([Symbol name=broadcast_in_dim])" -> "1152([Symbol name=mul])" - "1151([Symbol name=convert_element_type])" -> "1152([Symbol name=mul])" - "770([Symbol name=mul])" - "768([Symbol name=broadcast_in_dim])" -> "770([Symbol name=mul])" - "769([Symbol name=convert_element_type])" -> "770([Symbol name=mul])" - "1155([Symbol name=mul])" - "1153([Symbol name=broadcast_in_dim])" -> "1155([Symbol name=mul])" - "1154([Symbol name=convert_element_type])" -> "1155([Symbol name=mul])" - "1670([Symbol name=mul])" - "1668([Symbol name=broadcast_in_dim])" -> "1670([Symbol name=mul])" - "1669([Symbol name=convert_element_type])" -> "1670([Symbol name=mul])" - "270([Symbol name=mul])" - "268([Symbol name=broadcast_in_dim])" -> "270([Symbol name=mul])" - "269([Symbol name=convert_element_type])" -> "270([Symbol name=mul])" - "655([Symbol name=mul])" - "653([Symbol name=broadcast_in_dim])" -> "655([Symbol name=mul])" - "654([Symbol name=convert_element_type])" -> "655([Symbol name=mul])" - "1170([Symbol name=mul])" - "1168([Symbol name=broadcast_in_dim])" -> "1170([Symbol name=mul])" - "1169([Symbol name=convert_element_type])" -> "1170([Symbol name=mul])" - "1555([Symbol name=mul])" - "1553([Symbol name=broadcast_in_dim])" -> "1555([Symbol name=mul])" - "1554([Symbol name=convert_element_type])" -> "1555([Symbol name=mul])" - "155([Symbol name=mul])" - "153([Symbol name=broadcast_in_dim])" -> "155([Symbol name=mul])" - "154([Symbol name=convert_element_type])" -> "155([Symbol name=mul])" - "670([Symbol name=mul])" - "668([Symbol name=broadcast_in_dim])" -> "670([Symbol name=mul])" - "669([Symbol name=convert_element_type])" -> "670([Symbol name=mul])" - "1055([Symbol name=mul])" - "1053([Symbol name=broadcast_in_dim])" -> "1055([Symbol name=mul])" - "1054([Symbol name=convert_element_type])" -> "1055([Symbol name=mul])" - "1570([Symbol name=mul])" - "1568([Symbol name=broadcast_in_dim])" -> "1570([Symbol name=mul])" - "1569([Symbol name=convert_element_type])" -> "1570([Symbol name=mul])" - "170([Symbol name=mul])" - "168([Symbol name=broadcast_in_dim])" -> "170([Symbol name=mul])" - "169([Symbol name=convert_element_type])" -> "170([Symbol name=mul])" - "555([Symbol name=mul])" - "553([Symbol name=broadcast_in_dim])" -> "555([Symbol name=mul])" - "554([Symbol name=convert_element_type])" -> "555([Symbol name=mul])" - "1070([Symbol name=mul])" - "1068([Symbol name=broadcast_in_dim])" -> "1070([Symbol name=mul])" - "1069([Symbol name=convert_element_type])" -> "1070([Symbol name=mul])" - "1455([Symbol name=mul])" - "1453([Symbol name=broadcast_in_dim])" -> "1455([Symbol name=mul])" - "1454([Symbol name=convert_element_type])" -> "1455([Symbol name=mul])" - "570([Symbol name=mul])" - "568([Symbol name=broadcast_in_dim])" -> "570([Symbol name=mul])" - "569([Symbol name=convert_element_type])" -> "570([Symbol name=mul])" - "955([Symbol name=mul])" - "953([Symbol name=broadcast_in_dim])" -> "955([Symbol name=mul])" - "954([Symbol name=convert_element_type])" -> "955([Symbol name=mul])" - "1470([Symbol name=mul])" - "1468([Symbol name=broadcast_in_dim])" -> "1470([Symbol name=mul])" - "1469([Symbol name=convert_element_type])" -> "1470([Symbol name=mul])" - "455([Symbol name=mul])" - "453([Symbol name=broadcast_in_dim])" -> "455([Symbol name=mul])" - "454([Symbol name=convert_element_type])" -> "455([Symbol name=mul])" - "970([Symbol name=mul])" - "968([Symbol name=broadcast_in_dim])" -> "970([Symbol name=mul])" - "969([Symbol name=convert_element_type])" -> "970([Symbol name=mul])" - "1355([Symbol name=mul])" - "1353([Symbol name=broadcast_in_dim])" -> "1355([Symbol name=mul])" - "1354([Symbol name=convert_element_type])" -> "1355([Symbol name=mul])" - "470([Symbol name=mul])" - "468([Symbol name=broadcast_in_dim])" -> "470([Symbol name=mul])" - "469([Symbol name=convert_element_type])" -> "470([Symbol name=mul])" - "855([Symbol name=mul])" - "853([Symbol name=broadcast_in_dim])" -> "855([Symbol name=mul])" - "854([Symbol name=convert_element_type])" -> "855([Symbol name=mul])" - "1370([Symbol name=mul])" - "1368([Symbol name=broadcast_in_dim])" -> "1370([Symbol name=mul])" - "1369([Symbol name=convert_element_type])" -> "1370([Symbol name=mul])" - "355([Symbol name=mul])" - "353([Symbol name=broadcast_in_dim])" -> "355([Symbol name=mul])" - "354([Symbol name=convert_element_type])" -> "355([Symbol name=mul])" - "870([Symbol name=mul])" - "868([Symbol name=broadcast_in_dim])" -> "870([Symbol name=mul])" - "869([Symbol name=convert_element_type])" -> "870([Symbol name=mul])" - "1255([Symbol name=mul])" - "1253([Symbol name=broadcast_in_dim])" -> "1255([Symbol name=mul])" - "1254([Symbol name=convert_element_type])" -> "1255([Symbol name=mul])" - "370([Symbol name=mul])" - "368([Symbol name=broadcast_in_dim])" -> "370([Symbol name=mul])" - "369([Symbol name=convert_element_type])" -> "370([Symbol name=mul])" - "755([Symbol name=mul])" - "753([Symbol name=broadcast_in_dim])" -> "755([Symbol name=mul])" - "754([Symbol name=convert_element_type])" -> "755([Symbol name=mul])" - "1270([Symbol name=mul])" - "1268([Symbol name=broadcast_in_dim])" -> "1270([Symbol name=mul])" - "1269([Symbol name=convert_element_type])" -> "1270([Symbol name=mul])" - "1655([Symbol name=mul])" - "1653([Symbol name=broadcast_in_dim])" -> "1655([Symbol name=mul])" - "1654([Symbol name=convert_element_type])" -> "1655([Symbol name=mul])" - "255([Symbol name=mul])" - "253([Symbol name=broadcast_in_dim])" -> "255([Symbol name=mul])" - "254([Symbol name=convert_element_type])" -> "255([Symbol name=mul])" - "138([Symbol name=transpose])" - "137([Symbol name=reshape])" -> "138([Symbol name=transpose])" - "210([Symbol name=mul])" - "208([Symbol name=convert_element_type])" -> "210([Symbol name=mul])" - "209([Symbol name=convert_element_type])" -> "210([Symbol name=mul])" - "203([Symbol name=neg])" - "202([Symbol name=convert_element_type])" -> "203([Symbol name=neg])" - "214([Symbol name=mul])" - "212([Symbol name=convert_element_type])" -> "214([Symbol name=mul])" - "213([Symbol name=convert_element_type])" -> "214([Symbol name=mul])" - "219([Symbol name=add])" - "217([Symbol name=convert_element_type])" -> "219([Symbol name=add])" - "218([Symbol name=convert_element_type])" -> "219([Symbol name=add])" - "134([Symbol name=mul])" - "132([Symbol name=convert_element_type])" -> "134([Symbol name=mul])" - "133([Symbol name=convert_element_type])" -> "134([Symbol name=mul])" - "198([Symbol name=mul])" - "196([Symbol name=convert_element_type])" -> "198([Symbol name=mul])" - "197([Symbol name=convert_element_type])" -> "198([Symbol name=mul])" - "238([Symbol name=transpose])" - "237([Symbol name=reshape])" -> "238([Symbol name=transpose])" - "283([Symbol name=add])" - "281([Symbol name=convert_element_type])" -> "283([Symbol name=add])" - "282([Symbol name=convert_element_type])" -> "283([Symbol name=add])" - "310([Symbol name=mul])" - "308([Symbol name=convert_element_type])" -> "310([Symbol name=mul])" - "309([Symbol name=convert_element_type])" -> "310([Symbol name=mul])" - "303([Symbol name=neg])" - "302([Symbol name=convert_element_type])" -> "303([Symbol name=neg])" - "314([Symbol name=mul])" - "312([Symbol name=convert_element_type])" -> "314([Symbol name=mul])" - "313([Symbol name=convert_element_type])" -> "314([Symbol name=mul])" - "319([Symbol name=add])" - "317([Symbol name=convert_element_type])" -> "319([Symbol name=add])" - "318([Symbol name=convert_element_type])" -> "319([Symbol name=add])" - "234([Symbol name=mul])" - "232([Symbol name=convert_element_type])" -> "234([Symbol name=mul])" - "233([Symbol name=convert_element_type])" -> "234([Symbol name=mul])" - "298([Symbol name=mul])" - "296([Symbol name=convert_element_type])" -> "298([Symbol name=mul])" - "297([Symbol name=convert_element_type])" -> "298([Symbol name=mul])" - "338([Symbol name=transpose])" - "337([Symbol name=reshape])" -> "338([Symbol name=transpose])" - "383([Symbol name=add])" - "381([Symbol name=convert_element_type])" -> "383([Symbol name=add])" - "382([Symbol name=convert_element_type])" -> "383([Symbol name=add])" - "410([Symbol name=mul])" - "408([Symbol name=convert_element_type])" -> "410([Symbol name=mul])" - "409([Symbol name=convert_element_type])" -> "410([Symbol name=mul])" - "403([Symbol name=neg])" - "402([Symbol name=convert_element_type])" -> "403([Symbol name=neg])" - "414([Symbol name=mul])" - "412([Symbol name=convert_element_type])" -> "414([Symbol name=mul])" - "413([Symbol name=convert_element_type])" -> "414([Symbol name=mul])" - "419([Symbol name=add])" - "417([Symbol name=convert_element_type])" -> "419([Symbol name=add])" - "418([Symbol name=convert_element_type])" -> "419([Symbol name=add])" - "334([Symbol name=mul])" - "332([Symbol name=convert_element_type])" -> "334([Symbol name=mul])" - "333([Symbol name=convert_element_type])" -> "334([Symbol name=mul])" - "398([Symbol name=mul])" - "396([Symbol name=convert_element_type])" -> "398([Symbol name=mul])" - "397([Symbol name=convert_element_type])" -> "398([Symbol name=mul])" - "438([Symbol name=transpose])" - "437([Symbol name=reshape])" -> "438([Symbol name=transpose])" - "483([Symbol name=add])" - "481([Symbol name=convert_element_type])" -> "483([Symbol name=add])" - "482([Symbol name=convert_element_type])" -> "483([Symbol name=add])" - "510([Symbol name=mul])" - "508([Symbol name=convert_element_type])" -> "510([Symbol name=mul])" - "509([Symbol name=convert_element_type])" -> "510([Symbol name=mul])" - "503([Symbol name=neg])" - "502([Symbol name=convert_element_type])" -> "503([Symbol name=neg])" - "514([Symbol name=mul])" - "512([Symbol name=convert_element_type])" -> "514([Symbol name=mul])" - "513([Symbol name=convert_element_type])" -> "514([Symbol name=mul])" - "519([Symbol name=add])" - "517([Symbol name=convert_element_type])" -> "519([Symbol name=add])" - "518([Symbol name=convert_element_type])" -> "519([Symbol name=add])" - "434([Symbol name=mul])" - "432([Symbol name=convert_element_type])" -> "434([Symbol name=mul])" - "433([Symbol name=convert_element_type])" -> "434([Symbol name=mul])" - "498([Symbol name=mul])" - "496([Symbol name=convert_element_type])" -> "498([Symbol name=mul])" - "497([Symbol name=convert_element_type])" -> "498([Symbol name=mul])" - "538([Symbol name=transpose])" - "537([Symbol name=reshape])" -> "538([Symbol name=transpose])" - "583([Symbol name=add])" - "581([Symbol name=convert_element_type])" -> "583([Symbol name=add])" - "582([Symbol name=convert_element_type])" -> "583([Symbol name=add])" - "610([Symbol name=mul])" - "608([Symbol name=convert_element_type])" -> "610([Symbol name=mul])" - "609([Symbol name=convert_element_type])" -> "610([Symbol name=mul])" - "603([Symbol name=neg])" - "602([Symbol name=convert_element_type])" -> "603([Symbol name=neg])" - "614([Symbol name=mul])" - "612([Symbol name=convert_element_type])" -> "614([Symbol name=mul])" - "613([Symbol name=convert_element_type])" -> "614([Symbol name=mul])" - "619([Symbol name=add])" - "617([Symbol name=convert_element_type])" -> "619([Symbol name=add])" - "618([Symbol name=convert_element_type])" -> "619([Symbol name=add])" - "534([Symbol name=mul])" - "532([Symbol name=convert_element_type])" -> "534([Symbol name=mul])" - "533([Symbol name=convert_element_type])" -> "534([Symbol name=mul])" - "598([Symbol name=mul])" - "596([Symbol name=convert_element_type])" -> "598([Symbol name=mul])" - "597([Symbol name=convert_element_type])" -> "598([Symbol name=mul])" - "638([Symbol name=transpose])" - "637([Symbol name=reshape])" -> "638([Symbol name=transpose])" - "683([Symbol name=add])" - "681([Symbol name=convert_element_type])" -> "683([Symbol name=add])" - "682([Symbol name=convert_element_type])" -> "683([Symbol name=add])" - "710([Symbol name=mul])" - "708([Symbol name=convert_element_type])" -> "710([Symbol name=mul])" - "709([Symbol name=convert_element_type])" -> "710([Symbol name=mul])" - "703([Symbol name=neg])" - "702([Symbol name=convert_element_type])" -> "703([Symbol name=neg])" - "714([Symbol name=mul])" - "712([Symbol name=convert_element_type])" -> "714([Symbol name=mul])" - "713([Symbol name=convert_element_type])" -> "714([Symbol name=mul])" - "719([Symbol name=add])" - "717([Symbol name=convert_element_type])" -> "719([Symbol name=add])" - "718([Symbol name=convert_element_type])" -> "719([Symbol name=add])" - "634([Symbol name=mul])" - "632([Symbol name=convert_element_type])" -> "634([Symbol name=mul])" - "633([Symbol name=convert_element_type])" -> "634([Symbol name=mul])" - "698([Symbol name=mul])" - "696([Symbol name=convert_element_type])" -> "698([Symbol name=mul])" - "697([Symbol name=convert_element_type])" -> "698([Symbol name=mul])" - "738([Symbol name=transpose])" - "737([Symbol name=reshape])" -> "738([Symbol name=transpose])" - "783([Symbol name=add])" - "781([Symbol name=convert_element_type])" -> "783([Symbol name=add])" - "782([Symbol name=convert_element_type])" -> "783([Symbol name=add])" - "810([Symbol name=mul])" - "808([Symbol name=convert_element_type])" -> "810([Symbol name=mul])" - "809([Symbol name=convert_element_type])" -> "810([Symbol name=mul])" - "803([Symbol name=neg])" - "802([Symbol name=convert_element_type])" -> "803([Symbol name=neg])" - "814([Symbol name=mul])" - "812([Symbol name=convert_element_type])" -> "814([Symbol name=mul])" - "813([Symbol name=convert_element_type])" -> "814([Symbol name=mul])" - "819([Symbol name=add])" - "817([Symbol name=convert_element_type])" -> "819([Symbol name=add])" - "818([Symbol name=convert_element_type])" -> "819([Symbol name=add])" - "734([Symbol name=mul])" - "732([Symbol name=convert_element_type])" -> "734([Symbol name=mul])" - "733([Symbol name=convert_element_type])" -> "734([Symbol name=mul])" - "798([Symbol name=mul])" - "796([Symbol name=convert_element_type])" -> "798([Symbol name=mul])" - "797([Symbol name=convert_element_type])" -> "798([Symbol name=mul])" - "838([Symbol name=transpose])" - "837([Symbol name=reshape])" -> "838([Symbol name=transpose])" - "883([Symbol name=add])" - "881([Symbol name=convert_element_type])" -> "883([Symbol name=add])" - "882([Symbol name=convert_element_type])" -> "883([Symbol name=add])" - "910([Symbol name=mul])" - "908([Symbol name=convert_element_type])" -> "910([Symbol name=mul])" - "909([Symbol name=convert_element_type])" -> "910([Symbol name=mul])" - "903([Symbol name=neg])" - "902([Symbol name=convert_element_type])" -> "903([Symbol name=neg])" - "914([Symbol name=mul])" - "912([Symbol name=convert_element_type])" -> "914([Symbol name=mul])" - "913([Symbol name=convert_element_type])" -> "914([Symbol name=mul])" - "919([Symbol name=add])" - "917([Symbol name=convert_element_type])" -> "919([Symbol name=add])" - "918([Symbol name=convert_element_type])" -> "919([Symbol name=add])" - "834([Symbol name=mul])" - "832([Symbol name=convert_element_type])" -> "834([Symbol name=mul])" - "833([Symbol name=convert_element_type])" -> "834([Symbol name=mul])" - "898([Symbol name=mul])" - "896([Symbol name=convert_element_type])" -> "898([Symbol name=mul])" - "897([Symbol name=convert_element_type])" -> "898([Symbol name=mul])" - "938([Symbol name=transpose])" - "937([Symbol name=reshape])" -> "938([Symbol name=transpose])" - "983([Symbol name=add])" - "981([Symbol name=convert_element_type])" -> "983([Symbol name=add])" - "982([Symbol name=convert_element_type])" -> "983([Symbol name=add])" - "1010([Symbol name=mul])" - "1008([Symbol name=convert_element_type])" -> "1010([Symbol name=mul])" - "1009([Symbol name=convert_element_type])" -> "1010([Symbol name=mul])" - "1003([Symbol name=neg])" - "1002([Symbol name=convert_element_type])" -> "1003([Symbol name=neg])" - "1014([Symbol name=mul])" - "1012([Symbol name=convert_element_type])" -> "1014([Symbol name=mul])" - "1013([Symbol name=convert_element_type])" -> "1014([Symbol name=mul])" - "1019([Symbol name=add])" - "1017([Symbol name=convert_element_type])" -> "1019([Symbol name=add])" - "1018([Symbol name=convert_element_type])" -> "1019([Symbol name=add])" - "934([Symbol name=mul])" - "932([Symbol name=convert_element_type])" -> "934([Symbol name=mul])" - "933([Symbol name=convert_element_type])" -> "934([Symbol name=mul])" - "998([Symbol name=mul])" - "996([Symbol name=convert_element_type])" -> "998([Symbol name=mul])" - "997([Symbol name=convert_element_type])" -> "998([Symbol name=mul])" - "1038([Symbol name=transpose])" - "1037([Symbol name=reshape])" -> "1038([Symbol name=transpose])" - "1083([Symbol name=add])" - "1081([Symbol name=convert_element_type])" -> "1083([Symbol name=add])" - "1082([Symbol name=convert_element_type])" -> "1083([Symbol name=add])" - "1110([Symbol name=mul])" - "1108([Symbol name=convert_element_type])" -> "1110([Symbol name=mul])" - "1109([Symbol name=convert_element_type])" -> "1110([Symbol name=mul])" - "1103([Symbol name=neg])" - "1102([Symbol name=convert_element_type])" -> "1103([Symbol name=neg])" - "1114([Symbol name=mul])" - "1112([Symbol name=convert_element_type])" -> "1114([Symbol name=mul])" - "1113([Symbol name=convert_element_type])" -> "1114([Symbol name=mul])" - "1119([Symbol name=add])" - "1117([Symbol name=convert_element_type])" -> "1119([Symbol name=add])" - "1118([Symbol name=convert_element_type])" -> "1119([Symbol name=add])" - "1034([Symbol name=mul])" - "1032([Symbol name=convert_element_type])" -> "1034([Symbol name=mul])" - "1033([Symbol name=convert_element_type])" -> "1034([Symbol name=mul])" - "1098([Symbol name=mul])" - "1096([Symbol name=convert_element_type])" -> "1098([Symbol name=mul])" - "1097([Symbol name=convert_element_type])" -> "1098([Symbol name=mul])" - "1138([Symbol name=transpose])" - "1137([Symbol name=reshape])" -> "1138([Symbol name=transpose])" - "1183([Symbol name=add])" - "1181([Symbol name=convert_element_type])" -> "1183([Symbol name=add])" - "1182([Symbol name=convert_element_type])" -> "1183([Symbol name=add])" - "1210([Symbol name=mul])" - "1208([Symbol name=convert_element_type])" -> "1210([Symbol name=mul])" - "1209([Symbol name=convert_element_type])" -> "1210([Symbol name=mul])" - "1203([Symbol name=neg])" - "1202([Symbol name=convert_element_type])" -> "1203([Symbol name=neg])" - "1214([Symbol name=mul])" - "1212([Symbol name=convert_element_type])" -> "1214([Symbol name=mul])" - "1213([Symbol name=convert_element_type])" -> "1214([Symbol name=mul])" - "1219([Symbol name=add])" - "1217([Symbol name=convert_element_type])" -> "1219([Symbol name=add])" - "1218([Symbol name=convert_element_type])" -> "1219([Symbol name=add])" - "1134([Symbol name=mul])" - "1132([Symbol name=convert_element_type])" -> "1134([Symbol name=mul])" - "1133([Symbol name=convert_element_type])" -> "1134([Symbol name=mul])" - "1198([Symbol name=mul])" - "1196([Symbol name=convert_element_type])" -> "1198([Symbol name=mul])" - "1197([Symbol name=convert_element_type])" -> "1198([Symbol name=mul])" - "1238([Symbol name=transpose])" - "1237([Symbol name=reshape])" -> "1238([Symbol name=transpose])" - "1283([Symbol name=add])" - "1281([Symbol name=convert_element_type])" -> "1283([Symbol name=add])" - "1282([Symbol name=convert_element_type])" -> "1283([Symbol name=add])" - "1310([Symbol name=mul])" - "1308([Symbol name=convert_element_type])" -> "1310([Symbol name=mul])" - "1309([Symbol name=convert_element_type])" -> "1310([Symbol name=mul])" - "1303([Symbol name=neg])" - "1302([Symbol name=convert_element_type])" -> "1303([Symbol name=neg])" - "1314([Symbol name=mul])" - "1312([Symbol name=convert_element_type])" -> "1314([Symbol name=mul])" - "1313([Symbol name=convert_element_type])" -> "1314([Symbol name=mul])" - "1319([Symbol name=add])" - "1317([Symbol name=convert_element_type])" -> "1319([Symbol name=add])" - "1318([Symbol name=convert_element_type])" -> "1319([Symbol name=add])" - "1234([Symbol name=mul])" - "1232([Symbol name=convert_element_type])" -> "1234([Symbol name=mul])" - "1233([Symbol name=convert_element_type])" -> "1234([Symbol name=mul])" - "1298([Symbol name=mul])" - "1296([Symbol name=convert_element_type])" -> "1298([Symbol name=mul])" - "1297([Symbol name=convert_element_type])" -> "1298([Symbol name=mul])" - "1338([Symbol name=transpose])" - "1337([Symbol name=reshape])" -> "1338([Symbol name=transpose])" - "1383([Symbol name=add])" - "1381([Symbol name=convert_element_type])" -> "1383([Symbol name=add])" - "1382([Symbol name=convert_element_type])" -> "1383([Symbol name=add])" - "1410([Symbol name=mul])" - "1408([Symbol name=convert_element_type])" -> "1410([Symbol name=mul])" - "1409([Symbol name=convert_element_type])" -> "1410([Symbol name=mul])" - "1403([Symbol name=neg])" - "1402([Symbol name=convert_element_type])" -> "1403([Symbol name=neg])" - "1414([Symbol name=mul])" - "1412([Symbol name=convert_element_type])" -> "1414([Symbol name=mul])" - "1413([Symbol name=convert_element_type])" -> "1414([Symbol name=mul])" - "1419([Symbol name=add])" - "1417([Symbol name=convert_element_type])" -> "1419([Symbol name=add])" - "1418([Symbol name=convert_element_type])" -> "1419([Symbol name=add])" - "1334([Symbol name=mul])" - "1332([Symbol name=convert_element_type])" -> "1334([Symbol name=mul])" - "1333([Symbol name=convert_element_type])" -> "1334([Symbol name=mul])" - "1398([Symbol name=mul])" - "1396([Symbol name=convert_element_type])" -> "1398([Symbol name=mul])" - "1397([Symbol name=convert_element_type])" -> "1398([Symbol name=mul])" - "1438([Symbol name=transpose])" - "1437([Symbol name=reshape])" -> "1438([Symbol name=transpose])" - "1483([Symbol name=add])" - "1481([Symbol name=convert_element_type])" -> "1483([Symbol name=add])" - "1482([Symbol name=convert_element_type])" -> "1483([Symbol name=add])" - "1510([Symbol name=mul])" - "1508([Symbol name=convert_element_type])" -> "1510([Symbol name=mul])" - "1509([Symbol name=convert_element_type])" -> "1510([Symbol name=mul])" - "1503([Symbol name=neg])" - "1502([Symbol name=convert_element_type])" -> "1503([Symbol name=neg])" - "1514([Symbol name=mul])" - "1512([Symbol name=convert_element_type])" -> "1514([Symbol name=mul])" - "1513([Symbol name=convert_element_type])" -> "1514([Symbol name=mul])" - "1519([Symbol name=add])" - "1517([Symbol name=convert_element_type])" -> "1519([Symbol name=add])" - "1518([Symbol name=convert_element_type])" -> "1519([Symbol name=add])" - "1434([Symbol name=mul])" - "1432([Symbol name=convert_element_type])" -> "1434([Symbol name=mul])" - "1433([Symbol name=convert_element_type])" -> "1434([Symbol name=mul])" - "1498([Symbol name=mul])" - "1496([Symbol name=convert_element_type])" -> "1498([Symbol name=mul])" - "1497([Symbol name=convert_element_type])" -> "1498([Symbol name=mul])" - "1538([Symbol name=transpose])" - "1537([Symbol name=reshape])" -> "1538([Symbol name=transpose])" - "1583([Symbol name=add])" - "1581([Symbol name=convert_element_type])" -> "1583([Symbol name=add])" - "1582([Symbol name=convert_element_type])" -> "1583([Symbol name=add])" - "1610([Symbol name=mul])" - "1608([Symbol name=convert_element_type])" -> "1610([Symbol name=mul])" - "1609([Symbol name=convert_element_type])" -> "1610([Symbol name=mul])" - "1603([Symbol name=neg])" - "1602([Symbol name=convert_element_type])" -> "1603([Symbol name=neg])" - "1614([Symbol name=mul])" - "1612([Symbol name=convert_element_type])" -> "1614([Symbol name=mul])" - "1613([Symbol name=convert_element_type])" -> "1614([Symbol name=mul])" - "1619([Symbol name=add])" - "1617([Symbol name=convert_element_type])" -> "1619([Symbol name=add])" - "1618([Symbol name=convert_element_type])" -> "1619([Symbol name=add])" - "1534([Symbol name=mul])" - "1532([Symbol name=convert_element_type])" -> "1534([Symbol name=mul])" - "1533([Symbol name=convert_element_type])" -> "1534([Symbol name=mul])" - "1598([Symbol name=mul])" - "1596([Symbol name=convert_element_type])" -> "1598([Symbol name=mul])" - "1597([Symbol name=convert_element_type])" -> "1598([Symbol name=mul])" - "1638([Symbol name=transpose])" - "1637([Symbol name=reshape])" -> "1638([Symbol name=transpose])" - "1683([Symbol name=add])" - "1681([Symbol name=convert_element_type])" -> "1683([Symbol name=add])" - "1682([Symbol name=convert_element_type])" -> "1683([Symbol name=add])" - "1710([Symbol name=mul])" - "1708([Symbol name=convert_element_type])" -> "1710([Symbol name=mul])" - "1709([Symbol name=convert_element_type])" -> "1710([Symbol name=mul])" - "1703([Symbol name=neg])" - "1702([Symbol name=convert_element_type])" -> "1703([Symbol name=neg])" - "1714([Symbol name=mul])" - "1712([Symbol name=convert_element_type])" -> "1714([Symbol name=mul])" - "1713([Symbol name=convert_element_type])" -> "1714([Symbol name=mul])" - "1719([Symbol name=add])" - "1717([Symbol name=convert_element_type])" -> "1719([Symbol name=add])" - "1718([Symbol name=convert_element_type])" -> "1719([Symbol name=add])" - "1634([Symbol name=mul])" - "1632([Symbol name=convert_element_type])" -> "1634([Symbol name=mul])" - "1633([Symbol name=convert_element_type])" -> "1634([Symbol name=mul])" - "1698([Symbol name=mul])" - "1696([Symbol name=convert_element_type])" -> "1698([Symbol name=mul])" - "1697([Symbol name=convert_element_type])" -> "1698([Symbol name=mul])" - "1734([Symbol name=mul])" - "1732([Symbol name=convert_element_type])" -> "1734([Symbol name=mul])" - "1733([Symbol name=convert_element_type])" -> "1734([Symbol name=mul])" - "130([Symbol name=convert_element_type])" - "129([Symbol name=mul])" -> "130([Symbol name=convert_element_type])" - "123([Symbol name=sum])" - "122([Symbol name=mul])" -> "123([Symbol name=sum])" - "184([Symbol name=convert_element_type])" - "183([Symbol name=add])" -> "184([Symbol name=convert_element_type])" - "1671([Symbol name=add])" - "1667([Symbol name=mul])" -> "1671([Symbol name=add])" - "1670([Symbol name=mul])" -> "1671([Symbol name=add])" - "271([Symbol name=add])" - "267([Symbol name=mul])" -> "271([Symbol name=add])" - "270([Symbol name=mul])" -> "271([Symbol name=add])" - "656([Symbol name=add])" - "652([Symbol name=mul])" -> "656([Symbol name=add])" - "655([Symbol name=mul])" -> "656([Symbol name=add])" - "1171([Symbol name=add])" - "1170([Symbol name=mul])" -> "1171([Symbol name=add])" - "1167([Symbol name=mul])" -> "1171([Symbol name=add])" - "1556([Symbol name=add])" - "1552([Symbol name=mul])" -> "1556([Symbol name=add])" - "1555([Symbol name=mul])" -> "1556([Symbol name=add])" - "156([Symbol name=add])" - "152([Symbol name=mul])" -> "156([Symbol name=add])" - "155([Symbol name=mul])" -> "156([Symbol name=add])" - "671([Symbol name=add])" - "667([Symbol name=mul])" -> "671([Symbol name=add])" - "670([Symbol name=mul])" -> "671([Symbol name=add])" - "1056([Symbol name=add])" - "1052([Symbol name=mul])" -> "1056([Symbol name=add])" - "1055([Symbol name=mul])" -> "1056([Symbol name=add])" - "1571([Symbol name=add])" - "1570([Symbol name=mul])" -> "1571([Symbol name=add])" - "1567([Symbol name=mul])" -> "1571([Symbol name=add])" - "171([Symbol name=add])" - "170([Symbol name=mul])" -> "171([Symbol name=add])" - "167([Symbol name=mul])" -> "171([Symbol name=add])" - "556([Symbol name=add])" - "552([Symbol name=mul])" -> "556([Symbol name=add])" - "555([Symbol name=mul])" -> "556([Symbol name=add])" - "1071([Symbol name=add])" - "1067([Symbol name=mul])" -> "1071([Symbol name=add])" - "1070([Symbol name=mul])" -> "1071([Symbol name=add])" - "1456([Symbol name=add])" - "1452([Symbol name=mul])" -> "1456([Symbol name=add])" - "1455([Symbol name=mul])" -> "1456([Symbol name=add])" - "571([Symbol name=add])" - "570([Symbol name=mul])" -> "571([Symbol name=add])" - "567([Symbol name=mul])" -> "571([Symbol name=add])" - "956([Symbol name=add])" - "952([Symbol name=mul])" -> "956([Symbol name=add])" - "955([Symbol name=mul])" -> "956([Symbol name=add])" - "1471([Symbol name=add])" - "1467([Symbol name=mul])" -> "1471([Symbol name=add])" - "1470([Symbol name=mul])" -> "1471([Symbol name=add])" - "456([Symbol name=add])" - "452([Symbol name=mul])" -> "456([Symbol name=add])" - "455([Symbol name=mul])" -> "456([Symbol name=add])" - "971([Symbol name=add])" - "970([Symbol name=mul])" -> "971([Symbol name=add])" - "967([Symbol name=mul])" -> "971([Symbol name=add])" - "1356([Symbol name=add])" - "1352([Symbol name=mul])" -> "1356([Symbol name=add])" - "1355([Symbol name=mul])" -> "1356([Symbol name=add])" - "471([Symbol name=add])" - "467([Symbol name=mul])" -> "471([Symbol name=add])" - "470([Symbol name=mul])" -> "471([Symbol name=add])" - "856([Symbol name=add])" - "852([Symbol name=mul])" -> "856([Symbol name=add])" - "855([Symbol name=mul])" -> "856([Symbol name=add])" - "1371([Symbol name=add])" - "1370([Symbol name=mul])" -> "1371([Symbol name=add])" - "1367([Symbol name=mul])" -> "1371([Symbol name=add])" - "356([Symbol name=add])" - "352([Symbol name=mul])" -> "356([Symbol name=add])" - "355([Symbol name=mul])" -> "356([Symbol name=add])" - "871([Symbol name=add])" - "867([Symbol name=mul])" -> "871([Symbol name=add])" - "870([Symbol name=mul])" -> "871([Symbol name=add])" - "1256([Symbol name=add])" - "1252([Symbol name=mul])" -> "1256([Symbol name=add])" - "1255([Symbol name=mul])" -> "1256([Symbol name=add])" - "371([Symbol name=add])" - "370([Symbol name=mul])" -> "371([Symbol name=add])" - "367([Symbol name=mul])" -> "371([Symbol name=add])" - "756([Symbol name=add])" - "752([Symbol name=mul])" -> "756([Symbol name=add])" - "755([Symbol name=mul])" -> "756([Symbol name=add])" - "1271([Symbol name=add])" - "1267([Symbol name=mul])" -> "1271([Symbol name=add])" - "1270([Symbol name=mul])" -> "1271([Symbol name=add])" - "1656([Symbol name=add])" - "1652([Symbol name=mul])" -> "1656([Symbol name=add])" - "1655([Symbol name=mul])" -> "1656([Symbol name=add])" - "256([Symbol name=add])" - "252([Symbol name=mul])" -> "256([Symbol name=add])" - "255([Symbol name=mul])" -> "256([Symbol name=add])" - "771([Symbol name=add])" - "770([Symbol name=mul])" -> "771([Symbol name=add])" - "767([Symbol name=mul])" -> "771([Symbol name=add])" - "1156([Symbol name=add])" - "1152([Symbol name=mul])" -> "1156([Symbol name=add])" - "1155([Symbol name=mul])" -> "1156([Symbol name=add])" - "139([Symbol name=split])" - "138([Symbol name=transpose])" -> "139([Symbol name=split])" - "211([Symbol name=convert_element_type])" - "210([Symbol name=mul])" -> "211([Symbol name=convert_element_type])" - "204([Symbol name=exp])" - "203([Symbol name=neg])" -> "204([Symbol name=exp])" - "215([Symbol name=convert_element_type])" - "214([Symbol name=mul])" -> "215([Symbol name=convert_element_type])" - "220([Symbol name=convert_element_type])" - "219([Symbol name=add])" -> "220([Symbol name=convert_element_type])" - "135([Symbol name=convert_element_type])" - "134([Symbol name=mul])" -> "135([Symbol name=convert_element_type])" - "199([Symbol name=convert_element_type])" - "198([Symbol name=mul])" -> "199([Symbol name=convert_element_type])" - "239([Symbol name=split])" - "238([Symbol name=transpose])" -> "239([Symbol name=split])" - "284([Symbol name=convert_element_type])" - "283([Symbol name=add])" -> "284([Symbol name=convert_element_type])" - "311([Symbol name=convert_element_type])" - "310([Symbol name=mul])" -> "311([Symbol name=convert_element_type])" - "304([Symbol name=exp])" - "303([Symbol name=neg])" -> "304([Symbol name=exp])" - "315([Symbol name=convert_element_type])" - "314([Symbol name=mul])" -> "315([Symbol name=convert_element_type])" - "320([Symbol name=convert_element_type])" - "319([Symbol name=add])" -> "320([Symbol name=convert_element_type])" - "235([Symbol name=convert_element_type])" - "234([Symbol name=mul])" -> "235([Symbol name=convert_element_type])" - "299([Symbol name=convert_element_type])" - "298([Symbol name=mul])" -> "299([Symbol name=convert_element_type])" - "339([Symbol name=split])" - "338([Symbol name=transpose])" -> "339([Symbol name=split])" - "384([Symbol name=convert_element_type])" - "383([Symbol name=add])" -> "384([Symbol name=convert_element_type])" - "411([Symbol name=convert_element_type])" - "410([Symbol name=mul])" -> "411([Symbol name=convert_element_type])" - "404([Symbol name=exp])" - "403([Symbol name=neg])" -> "404([Symbol name=exp])" - "415([Symbol name=convert_element_type])" - "414([Symbol name=mul])" -> "415([Symbol name=convert_element_type])" - "420([Symbol name=convert_element_type])" - "419([Symbol name=add])" -> "420([Symbol name=convert_element_type])" - "335([Symbol name=convert_element_type])" - "334([Symbol name=mul])" -> "335([Symbol name=convert_element_type])" - "399([Symbol name=convert_element_type])" - "398([Symbol name=mul])" -> "399([Symbol name=convert_element_type])" - "439([Symbol name=split])" - "438([Symbol name=transpose])" -> "439([Symbol name=split])" - "484([Symbol name=convert_element_type])" - "483([Symbol name=add])" -> "484([Symbol name=convert_element_type])" - "511([Symbol name=convert_element_type])" - "510([Symbol name=mul])" -> "511([Symbol name=convert_element_type])" - "504([Symbol name=exp])" - "503([Symbol name=neg])" -> "504([Symbol name=exp])" - "515([Symbol name=convert_element_type])" - "514([Symbol name=mul])" -> "515([Symbol name=convert_element_type])" - "520([Symbol name=convert_element_type])" - "519([Symbol name=add])" -> "520([Symbol name=convert_element_type])" - "435([Symbol name=convert_element_type])" - "434([Symbol name=mul])" -> "435([Symbol name=convert_element_type])" - "499([Symbol name=convert_element_type])" - "498([Symbol name=mul])" -> "499([Symbol name=convert_element_type])" - "539([Symbol name=split])" - "538([Symbol name=transpose])" -> "539([Symbol name=split])" - "584([Symbol name=convert_element_type])" - "583([Symbol name=add])" -> "584([Symbol name=convert_element_type])" - "611([Symbol name=convert_element_type])" - "610([Symbol name=mul])" -> "611([Symbol name=convert_element_type])" - "604([Symbol name=exp])" - "603([Symbol name=neg])" -> "604([Symbol name=exp])" - "615([Symbol name=convert_element_type])" - "614([Symbol name=mul])" -> "615([Symbol name=convert_element_type])" - "620([Symbol name=convert_element_type])" - "619([Symbol name=add])" -> "620([Symbol name=convert_element_type])" - "535([Symbol name=convert_element_type])" - "534([Symbol name=mul])" -> "535([Symbol name=convert_element_type])" - "599([Symbol name=convert_element_type])" - "598([Symbol name=mul])" -> "599([Symbol name=convert_element_type])" - "639([Symbol name=split])" - "638([Symbol name=transpose])" -> "639([Symbol name=split])" - "684([Symbol name=convert_element_type])" - "683([Symbol name=add])" -> "684([Symbol name=convert_element_type])" - "711([Symbol name=convert_element_type])" - "710([Symbol name=mul])" -> "711([Symbol name=convert_element_type])" - "704([Symbol name=exp])" - "703([Symbol name=neg])" -> "704([Symbol name=exp])" - "715([Symbol name=convert_element_type])" - "714([Symbol name=mul])" -> "715([Symbol name=convert_element_type])" - "720([Symbol name=convert_element_type])" - "719([Symbol name=add])" -> "720([Symbol name=convert_element_type])" - "635([Symbol name=convert_element_type])" - "634([Symbol name=mul])" -> "635([Symbol name=convert_element_type])" - "699([Symbol name=convert_element_type])" - "698([Symbol name=mul])" -> "699([Symbol name=convert_element_type])" - "739([Symbol name=split])" - "738([Symbol name=transpose])" -> "739([Symbol name=split])" - "784([Symbol name=convert_element_type])" - "783([Symbol name=add])" -> "784([Symbol name=convert_element_type])" - "811([Symbol name=convert_element_type])" - "810([Symbol name=mul])" -> "811([Symbol name=convert_element_type])" - "804([Symbol name=exp])" - "803([Symbol name=neg])" -> "804([Symbol name=exp])" - "815([Symbol name=convert_element_type])" - "814([Symbol name=mul])" -> "815([Symbol name=convert_element_type])" - "820([Symbol name=convert_element_type])" - "819([Symbol name=add])" -> "820([Symbol name=convert_element_type])" - "735([Symbol name=convert_element_type])" - "734([Symbol name=mul])" -> "735([Symbol name=convert_element_type])" - "799([Symbol name=convert_element_type])" - "798([Symbol name=mul])" -> "799([Symbol name=convert_element_type])" - "839([Symbol name=split])" - "838([Symbol name=transpose])" -> "839([Symbol name=split])" - "884([Symbol name=convert_element_type])" - "883([Symbol name=add])" -> "884([Symbol name=convert_element_type])" - "911([Symbol name=convert_element_type])" - "910([Symbol name=mul])" -> "911([Symbol name=convert_element_type])" - "904([Symbol name=exp])" - "903([Symbol name=neg])" -> "904([Symbol name=exp])" - "915([Symbol name=convert_element_type])" - "914([Symbol name=mul])" -> "915([Symbol name=convert_element_type])" - "920([Symbol name=convert_element_type])" - "919([Symbol name=add])" -> "920([Symbol name=convert_element_type])" - "835([Symbol name=convert_element_type])" - "834([Symbol name=mul])" -> "835([Symbol name=convert_element_type])" - "899([Symbol name=convert_element_type])" - "898([Symbol name=mul])" -> "899([Symbol name=convert_element_type])" - "939([Symbol name=split])" - "938([Symbol name=transpose])" -> "939([Symbol name=split])" - "984([Symbol name=convert_element_type])" - "983([Symbol name=add])" -> "984([Symbol name=convert_element_type])" - "1011([Symbol name=convert_element_type])" - "1010([Symbol name=mul])" -> "1011([Symbol name=convert_element_type])" - "1004([Symbol name=exp])" - "1003([Symbol name=neg])" -> "1004([Symbol name=exp])" - "1015([Symbol name=convert_element_type])" - "1014([Symbol name=mul])" -> "1015([Symbol name=convert_element_type])" - "1020([Symbol name=convert_element_type])" - "1019([Symbol name=add])" -> "1020([Symbol name=convert_element_type])" - "935([Symbol name=convert_element_type])" - "934([Symbol name=mul])" -> "935([Symbol name=convert_element_type])" - "999([Symbol name=convert_element_type])" - "998([Symbol name=mul])" -> "999([Symbol name=convert_element_type])" - "1039([Symbol name=split])" - "1038([Symbol name=transpose])" -> "1039([Symbol name=split])" - "1084([Symbol name=convert_element_type])" - "1083([Symbol name=add])" -> "1084([Symbol name=convert_element_type])" - "1111([Symbol name=convert_element_type])" - "1110([Symbol name=mul])" -> "1111([Symbol name=convert_element_type])" - "1104([Symbol name=exp])" - "1103([Symbol name=neg])" -> "1104([Symbol name=exp])" - "1115([Symbol name=convert_element_type])" - "1114([Symbol name=mul])" -> "1115([Symbol name=convert_element_type])" - "1120([Symbol name=convert_element_type])" - "1119([Symbol name=add])" -> "1120([Symbol name=convert_element_type])" - "1035([Symbol name=convert_element_type])" - "1034([Symbol name=mul])" -> "1035([Symbol name=convert_element_type])" - "1099([Symbol name=convert_element_type])" - "1098([Symbol name=mul])" -> "1099([Symbol name=convert_element_type])" - "1139([Symbol name=split])" - "1138([Symbol name=transpose])" -> "1139([Symbol name=split])" - "1184([Symbol name=convert_element_type])" - "1183([Symbol name=add])" -> "1184([Symbol name=convert_element_type])" - "1211([Symbol name=convert_element_type])" - "1210([Symbol name=mul])" -> "1211([Symbol name=convert_element_type])" - "1204([Symbol name=exp])" - "1203([Symbol name=neg])" -> "1204([Symbol name=exp])" - "1215([Symbol name=convert_element_type])" - "1214([Symbol name=mul])" -> "1215([Symbol name=convert_element_type])" - "1220([Symbol name=convert_element_type])" - "1219([Symbol name=add])" -> "1220([Symbol name=convert_element_type])" - "1135([Symbol name=convert_element_type])" - "1134([Symbol name=mul])" -> "1135([Symbol name=convert_element_type])" - "1199([Symbol name=convert_element_type])" - "1198([Symbol name=mul])" -> "1199([Symbol name=convert_element_type])" - "1239([Symbol name=split])" - "1238([Symbol name=transpose])" -> "1239([Symbol name=split])" - "1284([Symbol name=convert_element_type])" - "1283([Symbol name=add])" -> "1284([Symbol name=convert_element_type])" - "1311([Symbol name=convert_element_type])" - "1310([Symbol name=mul])" -> "1311([Symbol name=convert_element_type])" - "1304([Symbol name=exp])" - "1303([Symbol name=neg])" -> "1304([Symbol name=exp])" - "1315([Symbol name=convert_element_type])" - "1314([Symbol name=mul])" -> "1315([Symbol name=convert_element_type])" - "1320([Symbol name=convert_element_type])" - "1319([Symbol name=add])" -> "1320([Symbol name=convert_element_type])" - "1235([Symbol name=convert_element_type])" - "1234([Symbol name=mul])" -> "1235([Symbol name=convert_element_type])" - "1299([Symbol name=convert_element_type])" - "1298([Symbol name=mul])" -> "1299([Symbol name=convert_element_type])" - "1339([Symbol name=split])" - "1338([Symbol name=transpose])" -> "1339([Symbol name=split])" - "1384([Symbol name=convert_element_type])" - "1383([Symbol name=add])" -> "1384([Symbol name=convert_element_type])" - "1411([Symbol name=convert_element_type])" - "1410([Symbol name=mul])" -> "1411([Symbol name=convert_element_type])" - "1404([Symbol name=exp])" - "1403([Symbol name=neg])" -> "1404([Symbol name=exp])" - "1415([Symbol name=convert_element_type])" - "1414([Symbol name=mul])" -> "1415([Symbol name=convert_element_type])" - "1420([Symbol name=convert_element_type])" - "1419([Symbol name=add])" -> "1420([Symbol name=convert_element_type])" - "1335([Symbol name=convert_element_type])" - "1334([Symbol name=mul])" -> "1335([Symbol name=convert_element_type])" - "1399([Symbol name=convert_element_type])" - "1398([Symbol name=mul])" -> "1399([Symbol name=convert_element_type])" - "1439([Symbol name=split])" - "1438([Symbol name=transpose])" -> "1439([Symbol name=split])" - "1484([Symbol name=convert_element_type])" - "1483([Symbol name=add])" -> "1484([Symbol name=convert_element_type])" - "1511([Symbol name=convert_element_type])" - "1510([Symbol name=mul])" -> "1511([Symbol name=convert_element_type])" - "1504([Symbol name=exp])" - "1503([Symbol name=neg])" -> "1504([Symbol name=exp])" - "1515([Symbol name=convert_element_type])" - "1514([Symbol name=mul])" -> "1515([Symbol name=convert_element_type])" - "1520([Symbol name=convert_element_type])" - "1519([Symbol name=add])" -> "1520([Symbol name=convert_element_type])" - "1435([Symbol name=convert_element_type])" - "1434([Symbol name=mul])" -> "1435([Symbol name=convert_element_type])" - "1499([Symbol name=convert_element_type])" - "1498([Symbol name=mul])" -> "1499([Symbol name=convert_element_type])" - "1539([Symbol name=split])" - "1538([Symbol name=transpose])" -> "1539([Symbol name=split])" - "1584([Symbol name=convert_element_type])" - "1583([Symbol name=add])" -> "1584([Symbol name=convert_element_type])" - "1611([Symbol name=convert_element_type])" - "1610([Symbol name=mul])" -> "1611([Symbol name=convert_element_type])" - "1604([Symbol name=exp])" - "1603([Symbol name=neg])" -> "1604([Symbol name=exp])" - "1615([Symbol name=convert_element_type])" - "1614([Symbol name=mul])" -> "1615([Symbol name=convert_element_type])" - "1620([Symbol name=convert_element_type])" - "1619([Symbol name=add])" -> "1620([Symbol name=convert_element_type])" - "1535([Symbol name=convert_element_type])" - "1534([Symbol name=mul])" -> "1535([Symbol name=convert_element_type])" - "1599([Symbol name=convert_element_type])" - "1598([Symbol name=mul])" -> "1599([Symbol name=convert_element_type])" - "1639([Symbol name=split])" - "1638([Symbol name=transpose])" -> "1639([Symbol name=split])" - "1684([Symbol name=convert_element_type])" - "1683([Symbol name=add])" -> "1684([Symbol name=convert_element_type])" - "1711([Symbol name=convert_element_type])" - "1710([Symbol name=mul])" -> "1711([Symbol name=convert_element_type])" - "1704([Symbol name=exp])" - "1703([Symbol name=neg])" -> "1704([Symbol name=exp])" - "1715([Symbol name=convert_element_type])" - "1714([Symbol name=mul])" -> "1715([Symbol name=convert_element_type])" - "1720([Symbol name=convert_element_type])" - "1719([Symbol name=add])" -> "1720([Symbol name=convert_element_type])" - "1635([Symbol name=convert_element_type])" - "1634([Symbol name=mul])" -> "1635([Symbol name=convert_element_type])" - "1699([Symbol name=convert_element_type])" - "1698([Symbol name=mul])" -> "1699([Symbol name=convert_element_type])" - "1735([Symbol name=convert_element_type])" - "1734([Symbol name=mul])" -> "1735([Symbol name=convert_element_type])" - "132([Symbol name=convert_element_type])" - "130([Symbol name=convert_element_type])" -> "132([Symbol name=convert_element_type])" - "124([Symbol name=broadcast_in_dim])" - "123([Symbol name=sum])" -> "124([Symbol name=broadcast_in_dim])" - "185([Symbol name=convert_element_type])" - "184([Symbol name=convert_element_type])" -> "185([Symbol name=convert_element_type])" - "218([Symbol name=convert_element_type])" - "184([Symbol name=convert_element_type])" -> "218([Symbol name=convert_element_type])" - "1672([Symbol name=convert_element_type])" - "1671([Symbol name=add])" -> "1672([Symbol name=convert_element_type])" - "272([Symbol name=convert_element_type])" - "271([Symbol name=add])" -> "272([Symbol name=convert_element_type])" - "657([Symbol name=convert_element_type])" - "656([Symbol name=add])" -> "657([Symbol name=convert_element_type])" - "1172([Symbol name=convert_element_type])" - "1171([Symbol name=add])" -> "1172([Symbol name=convert_element_type])" - "1557([Symbol name=convert_element_type])" - "1556([Symbol name=add])" -> "1557([Symbol name=convert_element_type])" - "157([Symbol name=convert_element_type])" - "156([Symbol name=add])" -> "157([Symbol name=convert_element_type])" - "672([Symbol name=convert_element_type])" - "671([Symbol name=add])" -> "672([Symbol name=convert_element_type])" - "1057([Symbol name=convert_element_type])" - "1056([Symbol name=add])" -> "1057([Symbol name=convert_element_type])" - "1572([Symbol name=convert_element_type])" - "1571([Symbol name=add])" -> "1572([Symbol name=convert_element_type])" - "172([Symbol name=convert_element_type])" - "171([Symbol name=add])" -> "172([Symbol name=convert_element_type])" - "557([Symbol name=convert_element_type])" - "556([Symbol name=add])" -> "557([Symbol name=convert_element_type])" - "1072([Symbol name=convert_element_type])" - "1071([Symbol name=add])" -> "1072([Symbol name=convert_element_type])" - "1457([Symbol name=convert_element_type])" - "1456([Symbol name=add])" -> "1457([Symbol name=convert_element_type])" - "572([Symbol name=convert_element_type])" - "571([Symbol name=add])" -> "572([Symbol name=convert_element_type])" - "957([Symbol name=convert_element_type])" - "956([Symbol name=add])" -> "957([Symbol name=convert_element_type])" - "1472([Symbol name=convert_element_type])" - "1471([Symbol name=add])" -> "1472([Symbol name=convert_element_type])" - "457([Symbol name=convert_element_type])" - "456([Symbol name=add])" -> "457([Symbol name=convert_element_type])" - "972([Symbol name=convert_element_type])" - "971([Symbol name=add])" -> "972([Symbol name=convert_element_type])" - "1357([Symbol name=convert_element_type])" - "1356([Symbol name=add])" -> "1357([Symbol name=convert_element_type])" - "472([Symbol name=convert_element_type])" - "471([Symbol name=add])" -> "472([Symbol name=convert_element_type])" - "857([Symbol name=convert_element_type])" - "856([Symbol name=add])" -> "857([Symbol name=convert_element_type])" - "1372([Symbol name=convert_element_type])" - "1371([Symbol name=add])" -> "1372([Symbol name=convert_element_type])" - "357([Symbol name=convert_element_type])" - "356([Symbol name=add])" -> "357([Symbol name=convert_element_type])" - "872([Symbol name=convert_element_type])" - "871([Symbol name=add])" -> "872([Symbol name=convert_element_type])" - "1257([Symbol name=convert_element_type])" - "1256([Symbol name=add])" -> "1257([Symbol name=convert_element_type])" - "372([Symbol name=convert_element_type])" - "371([Symbol name=add])" -> "372([Symbol name=convert_element_type])" - "757([Symbol name=convert_element_type])" - "756([Symbol name=add])" -> "757([Symbol name=convert_element_type])" - "1272([Symbol name=convert_element_type])" - "1271([Symbol name=add])" -> "1272([Symbol name=convert_element_type])" - "1657([Symbol name=convert_element_type])" - "1656([Symbol name=add])" -> "1657([Symbol name=convert_element_type])" - "257([Symbol name=convert_element_type])" - "256([Symbol name=add])" -> "257([Symbol name=convert_element_type])" - "772([Symbol name=convert_element_type])" - "771([Symbol name=add])" -> "772([Symbol name=convert_element_type])" - "1157([Symbol name=convert_element_type])" - "1156([Symbol name=add])" -> "1157([Symbol name=convert_element_type])" - "140([Symbol name=reshape])" - "139([Symbol name=split])" -> "140([Symbol name=reshape])" - "141([Symbol name=reshape])" - "139([Symbol name=split])" -> "141([Symbol name=reshape])" - "142([Symbol name=reshape])" - "139([Symbol name=split])" -> "142([Symbol name=reshape])" - "212([Symbol name=convert_element_type])" - "211([Symbol name=convert_element_type])" -> "212([Symbol name=convert_element_type])" - "205([Symbol name=add])" - "204([Symbol name=exp])" -> "205([Symbol name=add])" - "282([Symbol name=convert_element_type])" - "220([Symbol name=convert_element_type])" -> "282([Symbol name=convert_element_type])" - "221([Symbol name=convert_element_type])" - "220([Symbol name=convert_element_type])" -> "221([Symbol name=convert_element_type])" - "240([Symbol name=reshape])" - "239([Symbol name=split])" -> "240([Symbol name=reshape])" - "241([Symbol name=reshape])" - "239([Symbol name=split])" -> "241([Symbol name=reshape])" - "242([Symbol name=reshape])" - "239([Symbol name=split])" -> "242([Symbol name=reshape])" - "285([Symbol name=convert_element_type])" - "284([Symbol name=convert_element_type])" -> "285([Symbol name=convert_element_type])" - "318([Symbol name=convert_element_type])" - "284([Symbol name=convert_element_type])" -> "318([Symbol name=convert_element_type])" - "312([Symbol name=convert_element_type])" - "311([Symbol name=convert_element_type])" -> "312([Symbol name=convert_element_type])" - "305([Symbol name=add])" - "304([Symbol name=exp])" -> "305([Symbol name=add])" - "321([Symbol name=convert_element_type])" - "320([Symbol name=convert_element_type])" -> "321([Symbol name=convert_element_type])" - "382([Symbol name=convert_element_type])" - "320([Symbol name=convert_element_type])" -> "382([Symbol name=convert_element_type])" - "340([Symbol name=reshape])" - "339([Symbol name=split])" -> "340([Symbol name=reshape])" - "341([Symbol name=reshape])" - "339([Symbol name=split])" -> "341([Symbol name=reshape])" - "342([Symbol name=reshape])" - "339([Symbol name=split])" -> "342([Symbol name=reshape])" - "385([Symbol name=convert_element_type])" - "384([Symbol name=convert_element_type])" -> "385([Symbol name=convert_element_type])" - "418([Symbol name=convert_element_type])" - "384([Symbol name=convert_element_type])" -> "418([Symbol name=convert_element_type])" - "412([Symbol name=convert_element_type])" - "411([Symbol name=convert_element_type])" -> "412([Symbol name=convert_element_type])" - "405([Symbol name=add])" - "404([Symbol name=exp])" -> "405([Symbol name=add])" - "482([Symbol name=convert_element_type])" - "420([Symbol name=convert_element_type])" -> "482([Symbol name=convert_element_type])" - "421([Symbol name=convert_element_type])" - "420([Symbol name=convert_element_type])" -> "421([Symbol name=convert_element_type])" - "440([Symbol name=reshape])" - "439([Symbol name=split])" -> "440([Symbol name=reshape])" - "441([Symbol name=reshape])" - "439([Symbol name=split])" -> "441([Symbol name=reshape])" - "442([Symbol name=reshape])" - "439([Symbol name=split])" -> "442([Symbol name=reshape])" - "485([Symbol name=convert_element_type])" - "484([Symbol name=convert_element_type])" -> "485([Symbol name=convert_element_type])" - "518([Symbol name=convert_element_type])" - "484([Symbol name=convert_element_type])" -> "518([Symbol name=convert_element_type])" - "512([Symbol name=convert_element_type])" - "511([Symbol name=convert_element_type])" -> "512([Symbol name=convert_element_type])" - "505([Symbol name=add])" - "504([Symbol name=exp])" -> "505([Symbol name=add])" - "521([Symbol name=convert_element_type])" - "520([Symbol name=convert_element_type])" -> "521([Symbol name=convert_element_type])" - "582([Symbol name=convert_element_type])" - "520([Symbol name=convert_element_type])" -> "582([Symbol name=convert_element_type])" - "540([Symbol name=reshape])" - "539([Symbol name=split])" -> "540([Symbol name=reshape])" - "541([Symbol name=reshape])" - "539([Symbol name=split])" -> "541([Symbol name=reshape])" - "542([Symbol name=reshape])" - "539([Symbol name=split])" -> "542([Symbol name=reshape])" - "585([Symbol name=convert_element_type])" - "584([Symbol name=convert_element_type])" -> "585([Symbol name=convert_element_type])" - "618([Symbol name=convert_element_type])" - "584([Symbol name=convert_element_type])" -> "618([Symbol name=convert_element_type])" - "612([Symbol name=convert_element_type])" - "611([Symbol name=convert_element_type])" -> "612([Symbol name=convert_element_type])" - "605([Symbol name=add])" - "604([Symbol name=exp])" -> "605([Symbol name=add])" - "682([Symbol name=convert_element_type])" - "620([Symbol name=convert_element_type])" -> "682([Symbol name=convert_element_type])" - "621([Symbol name=convert_element_type])" - "620([Symbol name=convert_element_type])" -> "621([Symbol name=convert_element_type])" - "640([Symbol name=reshape])" - "639([Symbol name=split])" -> "640([Symbol name=reshape])" - "641([Symbol name=reshape])" - "639([Symbol name=split])" -> "641([Symbol name=reshape])" - "642([Symbol name=reshape])" - "639([Symbol name=split])" -> "642([Symbol name=reshape])" - "685([Symbol name=convert_element_type])" - "684([Symbol name=convert_element_type])" -> "685([Symbol name=convert_element_type])" - "718([Symbol name=convert_element_type])" - "684([Symbol name=convert_element_type])" -> "718([Symbol name=convert_element_type])" - "712([Symbol name=convert_element_type])" - "711([Symbol name=convert_element_type])" -> "712([Symbol name=convert_element_type])" - "705([Symbol name=add])" - "704([Symbol name=exp])" -> "705([Symbol name=add])" - "721([Symbol name=convert_element_type])" - "720([Symbol name=convert_element_type])" -> "721([Symbol name=convert_element_type])" - "782([Symbol name=convert_element_type])" - "720([Symbol name=convert_element_type])" -> "782([Symbol name=convert_element_type])" - "740([Symbol name=reshape])" - "739([Symbol name=split])" -> "740([Symbol name=reshape])" - "741([Symbol name=reshape])" - "739([Symbol name=split])" -> "741([Symbol name=reshape])" - "742([Symbol name=reshape])" - "739([Symbol name=split])" -> "742([Symbol name=reshape])" - "785([Symbol name=convert_element_type])" - "784([Symbol name=convert_element_type])" -> "785([Symbol name=convert_element_type])" - "818([Symbol name=convert_element_type])" - "784([Symbol name=convert_element_type])" -> "818([Symbol name=convert_element_type])" - "812([Symbol name=convert_element_type])" - "811([Symbol name=convert_element_type])" -> "812([Symbol name=convert_element_type])" - "805([Symbol name=add])" - "804([Symbol name=exp])" -> "805([Symbol name=add])" - "882([Symbol name=convert_element_type])" - "820([Symbol name=convert_element_type])" -> "882([Symbol name=convert_element_type])" - "821([Symbol name=convert_element_type])" - "820([Symbol name=convert_element_type])" -> "821([Symbol name=convert_element_type])" - "840([Symbol name=reshape])" - "839([Symbol name=split])" -> "840([Symbol name=reshape])" - "841([Symbol name=reshape])" - "839([Symbol name=split])" -> "841([Symbol name=reshape])" - "842([Symbol name=reshape])" - "839([Symbol name=split])" -> "842([Symbol name=reshape])" - "885([Symbol name=convert_element_type])" - "884([Symbol name=convert_element_type])" -> "885([Symbol name=convert_element_type])" - "918([Symbol name=convert_element_type])" - "884([Symbol name=convert_element_type])" -> "918([Symbol name=convert_element_type])" - "912([Symbol name=convert_element_type])" - "911([Symbol name=convert_element_type])" -> "912([Symbol name=convert_element_type])" - "905([Symbol name=add])" - "904([Symbol name=exp])" -> "905([Symbol name=add])" - "921([Symbol name=convert_element_type])" - "920([Symbol name=convert_element_type])" -> "921([Symbol name=convert_element_type])" - "982([Symbol name=convert_element_type])" - "920([Symbol name=convert_element_type])" -> "982([Symbol name=convert_element_type])" - "940([Symbol name=reshape])" - "939([Symbol name=split])" -> "940([Symbol name=reshape])" - "941([Symbol name=reshape])" - "939([Symbol name=split])" -> "941([Symbol name=reshape])" - "942([Symbol name=reshape])" - "939([Symbol name=split])" -> "942([Symbol name=reshape])" - "985([Symbol name=convert_element_type])" - "984([Symbol name=convert_element_type])" -> "985([Symbol name=convert_element_type])" - "1018([Symbol name=convert_element_type])" - "984([Symbol name=convert_element_type])" -> "1018([Symbol name=convert_element_type])" - "1012([Symbol name=convert_element_type])" - "1011([Symbol name=convert_element_type])" -> "1012([Symbol name=convert_element_type])" - "1005([Symbol name=add])" - "1004([Symbol name=exp])" -> "1005([Symbol name=add])" - "1082([Symbol name=convert_element_type])" - "1020([Symbol name=convert_element_type])" -> "1082([Symbol name=convert_element_type])" - "1021([Symbol name=convert_element_type])" - "1020([Symbol name=convert_element_type])" -> "1021([Symbol name=convert_element_type])" - "1040([Symbol name=reshape])" - "1039([Symbol name=split])" -> "1040([Symbol name=reshape])" - "1041([Symbol name=reshape])" - "1039([Symbol name=split])" -> "1041([Symbol name=reshape])" - "1042([Symbol name=reshape])" - "1039([Symbol name=split])" -> "1042([Symbol name=reshape])" - "1085([Symbol name=convert_element_type])" - "1084([Symbol name=convert_element_type])" -> "1085([Symbol name=convert_element_type])" - "1118([Symbol name=convert_element_type])" - "1084([Symbol name=convert_element_type])" -> "1118([Symbol name=convert_element_type])" - "1112([Symbol name=convert_element_type])" - "1111([Symbol name=convert_element_type])" -> "1112([Symbol name=convert_element_type])" - "1105([Symbol name=add])" - "1104([Symbol name=exp])" -> "1105([Symbol name=add])" - "1121([Symbol name=convert_element_type])" - "1120([Symbol name=convert_element_type])" -> "1121([Symbol name=convert_element_type])" - "1182([Symbol name=convert_element_type])" - "1120([Symbol name=convert_element_type])" -> "1182([Symbol name=convert_element_type])" - "1140([Symbol name=reshape])" - "1139([Symbol name=split])" -> "1140([Symbol name=reshape])" - "1141([Symbol name=reshape])" - "1139([Symbol name=split])" -> "1141([Symbol name=reshape])" - "1142([Symbol name=reshape])" - "1139([Symbol name=split])" -> "1142([Symbol name=reshape])" - "1185([Symbol name=convert_element_type])" - "1184([Symbol name=convert_element_type])" -> "1185([Symbol name=convert_element_type])" - "1218([Symbol name=convert_element_type])" - "1184([Symbol name=convert_element_type])" -> "1218([Symbol name=convert_element_type])" - "1212([Symbol name=convert_element_type])" - "1211([Symbol name=convert_element_type])" -> "1212([Symbol name=convert_element_type])" - "1205([Symbol name=add])" - "1204([Symbol name=exp])" -> "1205([Symbol name=add])" - "1282([Symbol name=convert_element_type])" - "1220([Symbol name=convert_element_type])" -> "1282([Symbol name=convert_element_type])" - "1221([Symbol name=convert_element_type])" - "1220([Symbol name=convert_element_type])" -> "1221([Symbol name=convert_element_type])" - "1240([Symbol name=reshape])" - "1239([Symbol name=split])" -> "1240([Symbol name=reshape])" - "1241([Symbol name=reshape])" - "1239([Symbol name=split])" -> "1241([Symbol name=reshape])" - "1242([Symbol name=reshape])" - "1239([Symbol name=split])" -> "1242([Symbol name=reshape])" - "1285([Symbol name=convert_element_type])" - "1284([Symbol name=convert_element_type])" -> "1285([Symbol name=convert_element_type])" - "1318([Symbol name=convert_element_type])" - "1284([Symbol name=convert_element_type])" -> "1318([Symbol name=convert_element_type])" - "1312([Symbol name=convert_element_type])" - "1311([Symbol name=convert_element_type])" -> "1312([Symbol name=convert_element_type])" - "1305([Symbol name=add])" - "1304([Symbol name=exp])" -> "1305([Symbol name=add])" - "1321([Symbol name=convert_element_type])" - "1320([Symbol name=convert_element_type])" -> "1321([Symbol name=convert_element_type])" - "1382([Symbol name=convert_element_type])" - "1320([Symbol name=convert_element_type])" -> "1382([Symbol name=convert_element_type])" - "1340([Symbol name=reshape])" - "1339([Symbol name=split])" -> "1340([Symbol name=reshape])" - "1341([Symbol name=reshape])" - "1339([Symbol name=split])" -> "1341([Symbol name=reshape])" - "1342([Symbol name=reshape])" - "1339([Symbol name=split])" -> "1342([Symbol name=reshape])" - "1385([Symbol name=convert_element_type])" - "1384([Symbol name=convert_element_type])" -> "1385([Symbol name=convert_element_type])" - "1418([Symbol name=convert_element_type])" - "1384([Symbol name=convert_element_type])" -> "1418([Symbol name=convert_element_type])" - "1412([Symbol name=convert_element_type])" - "1411([Symbol name=convert_element_type])" -> "1412([Symbol name=convert_element_type])" - "1405([Symbol name=add])" - "1404([Symbol name=exp])" -> "1405([Symbol name=add])" - "1482([Symbol name=convert_element_type])" - "1420([Symbol name=convert_element_type])" -> "1482([Symbol name=convert_element_type])" - "1421([Symbol name=convert_element_type])" - "1420([Symbol name=convert_element_type])" -> "1421([Symbol name=convert_element_type])" - "1440([Symbol name=reshape])" - "1439([Symbol name=split])" -> "1440([Symbol name=reshape])" - "1441([Symbol name=reshape])" - "1439([Symbol name=split])" -> "1441([Symbol name=reshape])" - "1442([Symbol name=reshape])" - "1439([Symbol name=split])" -> "1442([Symbol name=reshape])" - "1485([Symbol name=convert_element_type])" - "1484([Symbol name=convert_element_type])" -> "1485([Symbol name=convert_element_type])" - "1518([Symbol name=convert_element_type])" - "1484([Symbol name=convert_element_type])" -> "1518([Symbol name=convert_element_type])" - "1512([Symbol name=convert_element_type])" - "1511([Symbol name=convert_element_type])" -> "1512([Symbol name=convert_element_type])" - "1505([Symbol name=add])" - "1504([Symbol name=exp])" -> "1505([Symbol name=add])" - "1521([Symbol name=convert_element_type])" - "1520([Symbol name=convert_element_type])" -> "1521([Symbol name=convert_element_type])" - "1582([Symbol name=convert_element_type])" - "1520([Symbol name=convert_element_type])" -> "1582([Symbol name=convert_element_type])" - "1540([Symbol name=reshape])" - "1539([Symbol name=split])" -> "1540([Symbol name=reshape])" - "1541([Symbol name=reshape])" - "1539([Symbol name=split])" -> "1541([Symbol name=reshape])" - "1542([Symbol name=reshape])" - "1539([Symbol name=split])" -> "1542([Symbol name=reshape])" - "1585([Symbol name=convert_element_type])" - "1584([Symbol name=convert_element_type])" -> "1585([Symbol name=convert_element_type])" - "1618([Symbol name=convert_element_type])" - "1584([Symbol name=convert_element_type])" -> "1618([Symbol name=convert_element_type])" - "1612([Symbol name=convert_element_type])" - "1611([Symbol name=convert_element_type])" -> "1612([Symbol name=convert_element_type])" - "1605([Symbol name=add])" - "1604([Symbol name=exp])" -> "1605([Symbol name=add])" - "1682([Symbol name=convert_element_type])" - "1620([Symbol name=convert_element_type])" -> "1682([Symbol name=convert_element_type])" - "1621([Symbol name=convert_element_type])" - "1620([Symbol name=convert_element_type])" -> "1621([Symbol name=convert_element_type])" - "1640([Symbol name=reshape])" - "1639([Symbol name=split])" -> "1640([Symbol name=reshape])" - "1641([Symbol name=reshape])" - "1639([Symbol name=split])" -> "1641([Symbol name=reshape])" - "1642([Symbol name=reshape])" - "1639([Symbol name=split])" -> "1642([Symbol name=reshape])" - "1685([Symbol name=convert_element_type])" - "1684([Symbol name=convert_element_type])" -> "1685([Symbol name=convert_element_type])" - "1718([Symbol name=convert_element_type])" - "1684([Symbol name=convert_element_type])" -> "1718([Symbol name=convert_element_type])" - "1712([Symbol name=convert_element_type])" - "1711([Symbol name=convert_element_type])" -> "1712([Symbol name=convert_element_type])" - "1705([Symbol name=add])" - "1704([Symbol name=exp])" -> "1705([Symbol name=add])" - "1721([Symbol name=convert_element_type])" - "1720([Symbol name=convert_element_type])" -> "1721([Symbol name=convert_element_type])" - "125([Symbol name=true_divide])" - "124([Symbol name=broadcast_in_dim])" -> "125([Symbol name=true_divide])" - "193([Symbol name=mul])" - "192([Symbol name=broadcast_in_dim])" -> "193([Symbol name=mul])" - "185([Symbol name=convert_element_type])" -> "193([Symbol name=mul])" - "186([Symbol name=mul])" - "185([Symbol name=convert_element_type])" -> "186([Symbol name=mul])" - "1676([Symbol name=cat])" - "1672([Symbol name=convert_element_type])" -> "1676([Symbol name=cat])" - "1675([Symbol name=slice_prim])" -> "1676([Symbol name=cat])" - "276([Symbol name=cat])" - "272([Symbol name=convert_element_type])" -> "276([Symbol name=cat])" - "275([Symbol name=slice_prim])" -> "276([Symbol name=cat])" - "674([Symbol name=cat])" - "657([Symbol name=convert_element_type])" -> "674([Symbol name=cat])" - "673([Symbol name=slice_prim])" -> "674([Symbol name=cat])" - "1176([Symbol name=cat])" - "1172([Symbol name=convert_element_type])" -> "1176([Symbol name=cat])" - "1175([Symbol name=slice_prim])" -> "1176([Symbol name=cat])" - "1574([Symbol name=cat])" - "1573([Symbol name=slice_prim])" -> "1574([Symbol name=cat])" - "1557([Symbol name=convert_element_type])" -> "1574([Symbol name=cat])" - "174([Symbol name=cat])" - "157([Symbol name=convert_element_type])" -> "174([Symbol name=cat])" - "173([Symbol name=slice_prim])" -> "174([Symbol name=cat])" - "676([Symbol name=cat])" - "672([Symbol name=convert_element_type])" -> "676([Symbol name=cat])" - "675([Symbol name=slice_prim])" -> "676([Symbol name=cat])" - "1074([Symbol name=cat])" - "1057([Symbol name=convert_element_type])" -> "1074([Symbol name=cat])" - "1073([Symbol name=slice_prim])" -> "1074([Symbol name=cat])" - "1576([Symbol name=cat])" - "1572([Symbol name=convert_element_type])" -> "1576([Symbol name=cat])" - "1575([Symbol name=slice_prim])" -> "1576([Symbol name=cat])" - "176([Symbol name=cat])" - "172([Symbol name=convert_element_type])" -> "176([Symbol name=cat])" - "175([Symbol name=slice_prim])" -> "176([Symbol name=cat])" - "574([Symbol name=cat])" - "573([Symbol name=slice_prim])" -> "574([Symbol name=cat])" - "557([Symbol name=convert_element_type])" -> "574([Symbol name=cat])" - "1076([Symbol name=cat])" - "1072([Symbol name=convert_element_type])" -> "1076([Symbol name=cat])" - "1075([Symbol name=slice_prim])" -> "1076([Symbol name=cat])" - "1474([Symbol name=cat])" - "1457([Symbol name=convert_element_type])" -> "1474([Symbol name=cat])" - "1473([Symbol name=slice_prim])" -> "1474([Symbol name=cat])" - "576([Symbol name=cat])" - "572([Symbol name=convert_element_type])" -> "576([Symbol name=cat])" - "575([Symbol name=slice_prim])" -> "576([Symbol name=cat])" - "974([Symbol name=cat])" - "973([Symbol name=slice_prim])" -> "974([Symbol name=cat])" - "957([Symbol name=convert_element_type])" -> "974([Symbol name=cat])" - "1476([Symbol name=cat])" - "1472([Symbol name=convert_element_type])" -> "1476([Symbol name=cat])" - "1475([Symbol name=slice_prim])" -> "1476([Symbol name=cat])" - "474([Symbol name=cat])" - "457([Symbol name=convert_element_type])" -> "474([Symbol name=cat])" - "473([Symbol name=slice_prim])" -> "474([Symbol name=cat])" - "976([Symbol name=cat])" - "972([Symbol name=convert_element_type])" -> "976([Symbol name=cat])" - "975([Symbol name=slice_prim])" -> "976([Symbol name=cat])" - "1374([Symbol name=cat])" - "1373([Symbol name=slice_prim])" -> "1374([Symbol name=cat])" - "1357([Symbol name=convert_element_type])" -> "1374([Symbol name=cat])" - "476([Symbol name=cat])" - "472([Symbol name=convert_element_type])" -> "476([Symbol name=cat])" - "475([Symbol name=slice_prim])" -> "476([Symbol name=cat])" - "874([Symbol name=cat])" - "857([Symbol name=convert_element_type])" -> "874([Symbol name=cat])" - "873([Symbol name=slice_prim])" -> "874([Symbol name=cat])" - "1376([Symbol name=cat])" - "1372([Symbol name=convert_element_type])" -> "1376([Symbol name=cat])" - "1375([Symbol name=slice_prim])" -> "1376([Symbol name=cat])" - "374([Symbol name=cat])" - "373([Symbol name=slice_prim])" -> "374([Symbol name=cat])" - "357([Symbol name=convert_element_type])" -> "374([Symbol name=cat])" - "876([Symbol name=cat])" - "872([Symbol name=convert_element_type])" -> "876([Symbol name=cat])" - "875([Symbol name=slice_prim])" -> "876([Symbol name=cat])" - "1274([Symbol name=cat])" - "1257([Symbol name=convert_element_type])" -> "1274([Symbol name=cat])" - "1273([Symbol name=slice_prim])" -> "1274([Symbol name=cat])" - "376([Symbol name=cat])" - "372([Symbol name=convert_element_type])" -> "376([Symbol name=cat])" - "375([Symbol name=slice_prim])" -> "376([Symbol name=cat])" - "774([Symbol name=cat])" - "773([Symbol name=slice_prim])" -> "774([Symbol name=cat])" - "757([Symbol name=convert_element_type])" -> "774([Symbol name=cat])" - "1276([Symbol name=cat])" - "1272([Symbol name=convert_element_type])" -> "1276([Symbol name=cat])" - "1275([Symbol name=slice_prim])" -> "1276([Symbol name=cat])" - "1674([Symbol name=cat])" - "1657([Symbol name=convert_element_type])" -> "1674([Symbol name=cat])" - "1673([Symbol name=slice_prim])" -> "1674([Symbol name=cat])" - "274([Symbol name=cat])" - "257([Symbol name=convert_element_type])" -> "274([Symbol name=cat])" - "273([Symbol name=slice_prim])" -> "274([Symbol name=cat])" - "776([Symbol name=cat])" - "772([Symbol name=convert_element_type])" -> "776([Symbol name=cat])" - "775([Symbol name=slice_prim])" -> "776([Symbol name=cat])" - "1174([Symbol name=cat])" - "1157([Symbol name=convert_element_type])" -> "1174([Symbol name=cat])" - "1173([Symbol name=slice_prim])" -> "1174([Symbol name=cat])" - "173([Symbol name=slice_prim])" - "140([Symbol name=reshape])" -> "173([Symbol name=slice_prim])" - "143([Symbol name=slice_prim])" - "140([Symbol name=reshape])" -> "143([Symbol name=slice_prim])" - "158([Symbol name=slice_prim])" - "141([Symbol name=reshape])" -> "158([Symbol name=slice_prim])" - "175([Symbol name=slice_prim])" - "141([Symbol name=reshape])" -> "175([Symbol name=slice_prim])" - "177([Symbol name=cudnn_sdpa_fwd])" - "176([Symbol name=cat])" -> "177([Symbol name=cudnn_sdpa_fwd])" - "142([Symbol name=reshape])" -> "177([Symbol name=cudnn_sdpa_fwd])" - "174([Symbol name=cat])" -> "177([Symbol name=cudnn_sdpa_fwd])" - "206([Symbol name=reciprocal])" - "205([Symbol name=add])" -> "206([Symbol name=reciprocal])" - "229([Symbol name=mul])" - "228([Symbol name=broadcast_in_dim])" -> "229([Symbol name=mul])" - "221([Symbol name=convert_element_type])" -> "229([Symbol name=mul])" - "222([Symbol name=mul])" - "221([Symbol name=convert_element_type])" -> "222([Symbol name=mul])" - "273([Symbol name=slice_prim])" - "240([Symbol name=reshape])" -> "273([Symbol name=slice_prim])" - "243([Symbol name=slice_prim])" - "240([Symbol name=reshape])" -> "243([Symbol name=slice_prim])" - "258([Symbol name=slice_prim])" - "241([Symbol name=reshape])" -> "258([Symbol name=slice_prim])" - "275([Symbol name=slice_prim])" - "241([Symbol name=reshape])" -> "275([Symbol name=slice_prim])" - "277([Symbol name=cudnn_sdpa_fwd])" - "274([Symbol name=cat])" -> "277([Symbol name=cudnn_sdpa_fwd])" - "242([Symbol name=reshape])" -> "277([Symbol name=cudnn_sdpa_fwd])" - "276([Symbol name=cat])" -> "277([Symbol name=cudnn_sdpa_fwd])" - "293([Symbol name=mul])" - "292([Symbol name=broadcast_in_dim])" -> "293([Symbol name=mul])" - "285([Symbol name=convert_element_type])" -> "293([Symbol name=mul])" - "286([Symbol name=mul])" - "285([Symbol name=convert_element_type])" -> "286([Symbol name=mul])" - "306([Symbol name=reciprocal])" - "305([Symbol name=add])" -> "306([Symbol name=reciprocal])" - "329([Symbol name=mul])" - "328([Symbol name=broadcast_in_dim])" -> "329([Symbol name=mul])" - "321([Symbol name=convert_element_type])" -> "329([Symbol name=mul])" - "322([Symbol name=mul])" - "321([Symbol name=convert_element_type])" -> "322([Symbol name=mul])" - "373([Symbol name=slice_prim])" - "340([Symbol name=reshape])" -> "373([Symbol name=slice_prim])" - "343([Symbol name=slice_prim])" - "340([Symbol name=reshape])" -> "343([Symbol name=slice_prim])" - "358([Symbol name=slice_prim])" - "341([Symbol name=reshape])" -> "358([Symbol name=slice_prim])" - "375([Symbol name=slice_prim])" - "341([Symbol name=reshape])" -> "375([Symbol name=slice_prim])" - "377([Symbol name=cudnn_sdpa_fwd])" - "376([Symbol name=cat])" -> "377([Symbol name=cudnn_sdpa_fwd])" - "342([Symbol name=reshape])" -> "377([Symbol name=cudnn_sdpa_fwd])" - "374([Symbol name=cat])" -> "377([Symbol name=cudnn_sdpa_fwd])" - "393([Symbol name=mul])" - "392([Symbol name=broadcast_in_dim])" -> "393([Symbol name=mul])" - "385([Symbol name=convert_element_type])" -> "393([Symbol name=mul])" - "386([Symbol name=mul])" - "385([Symbol name=convert_element_type])" -> "386([Symbol name=mul])" - "406([Symbol name=reciprocal])" - "405([Symbol name=add])" -> "406([Symbol name=reciprocal])" - "429([Symbol name=mul])" - "428([Symbol name=broadcast_in_dim])" -> "429([Symbol name=mul])" - "421([Symbol name=convert_element_type])" -> "429([Symbol name=mul])" - "422([Symbol name=mul])" - "421([Symbol name=convert_element_type])" -> "422([Symbol name=mul])" - "473([Symbol name=slice_prim])" - "440([Symbol name=reshape])" -> "473([Symbol name=slice_prim])" - "443([Symbol name=slice_prim])" - "440([Symbol name=reshape])" -> "443([Symbol name=slice_prim])" - "458([Symbol name=slice_prim])" - "441([Symbol name=reshape])" -> "458([Symbol name=slice_prim])" - "475([Symbol name=slice_prim])" - "441([Symbol name=reshape])" -> "475([Symbol name=slice_prim])" - "477([Symbol name=cudnn_sdpa_fwd])" - "442([Symbol name=reshape])" -> "477([Symbol name=cudnn_sdpa_fwd])" - "474([Symbol name=cat])" -> "477([Symbol name=cudnn_sdpa_fwd])" - "476([Symbol name=cat])" -> "477([Symbol name=cudnn_sdpa_fwd])" - "493([Symbol name=mul])" - "492([Symbol name=broadcast_in_dim])" -> "493([Symbol name=mul])" - "485([Symbol name=convert_element_type])" -> "493([Symbol name=mul])" - "486([Symbol name=mul])" - "485([Symbol name=convert_element_type])" -> "486([Symbol name=mul])" - "506([Symbol name=reciprocal])" - "505([Symbol name=add])" -> "506([Symbol name=reciprocal])" - "529([Symbol name=mul])" - "528([Symbol name=broadcast_in_dim])" -> "529([Symbol name=mul])" - "521([Symbol name=convert_element_type])" -> "529([Symbol name=mul])" - "522([Symbol name=mul])" - "521([Symbol name=convert_element_type])" -> "522([Symbol name=mul])" - "573([Symbol name=slice_prim])" - "540([Symbol name=reshape])" -> "573([Symbol name=slice_prim])" - "543([Symbol name=slice_prim])" - "540([Symbol name=reshape])" -> "543([Symbol name=slice_prim])" - "558([Symbol name=slice_prim])" - "541([Symbol name=reshape])" -> "558([Symbol name=slice_prim])" - "575([Symbol name=slice_prim])" - "541([Symbol name=reshape])" -> "575([Symbol name=slice_prim])" - "577([Symbol name=cudnn_sdpa_fwd])" - "576([Symbol name=cat])" -> "577([Symbol name=cudnn_sdpa_fwd])" - "574([Symbol name=cat])" -> "577([Symbol name=cudnn_sdpa_fwd])" - "542([Symbol name=reshape])" -> "577([Symbol name=cudnn_sdpa_fwd])" - "593([Symbol name=mul])" - "592([Symbol name=broadcast_in_dim])" -> "593([Symbol name=mul])" - "585([Symbol name=convert_element_type])" -> "593([Symbol name=mul])" - "586([Symbol name=mul])" - "585([Symbol name=convert_element_type])" -> "586([Symbol name=mul])" - "606([Symbol name=reciprocal])" - "605([Symbol name=add])" -> "606([Symbol name=reciprocal])" - "629([Symbol name=mul])" - "628([Symbol name=broadcast_in_dim])" -> "629([Symbol name=mul])" - "621([Symbol name=convert_element_type])" -> "629([Symbol name=mul])" - "622([Symbol name=mul])" - "621([Symbol name=convert_element_type])" -> "622([Symbol name=mul])" - "673([Symbol name=slice_prim])" - "640([Symbol name=reshape])" -> "673([Symbol name=slice_prim])" - "643([Symbol name=slice_prim])" - "640([Symbol name=reshape])" -> "643([Symbol name=slice_prim])" - "658([Symbol name=slice_prim])" - "641([Symbol name=reshape])" -> "658([Symbol name=slice_prim])" - "675([Symbol name=slice_prim])" - "641([Symbol name=reshape])" -> "675([Symbol name=slice_prim])" - "677([Symbol name=cudnn_sdpa_fwd])" - "674([Symbol name=cat])" -> "677([Symbol name=cudnn_sdpa_fwd])" - "676([Symbol name=cat])" -> "677([Symbol name=cudnn_sdpa_fwd])" - "642([Symbol name=reshape])" -> "677([Symbol name=cudnn_sdpa_fwd])" - "693([Symbol name=mul])" - "692([Symbol name=broadcast_in_dim])" -> "693([Symbol name=mul])" - "685([Symbol name=convert_element_type])" -> "693([Symbol name=mul])" - "686([Symbol name=mul])" - "685([Symbol name=convert_element_type])" -> "686([Symbol name=mul])" - "706([Symbol name=reciprocal])" - "705([Symbol name=add])" -> "706([Symbol name=reciprocal])" - "729([Symbol name=mul])" - "728([Symbol name=broadcast_in_dim])" -> "729([Symbol name=mul])" - "721([Symbol name=convert_element_type])" -> "729([Symbol name=mul])" - "722([Symbol name=mul])" - "721([Symbol name=convert_element_type])" -> "722([Symbol name=mul])" - "773([Symbol name=slice_prim])" - "740([Symbol name=reshape])" -> "773([Symbol name=slice_prim])" - "743([Symbol name=slice_prim])" - "740([Symbol name=reshape])" -> "743([Symbol name=slice_prim])" - "758([Symbol name=slice_prim])" - "741([Symbol name=reshape])" -> "758([Symbol name=slice_prim])" - "775([Symbol name=slice_prim])" - "741([Symbol name=reshape])" -> "775([Symbol name=slice_prim])" - "777([Symbol name=cudnn_sdpa_fwd])" - "776([Symbol name=cat])" -> "777([Symbol name=cudnn_sdpa_fwd])" - "774([Symbol name=cat])" -> "777([Symbol name=cudnn_sdpa_fwd])" - "742([Symbol name=reshape])" -> "777([Symbol name=cudnn_sdpa_fwd])" - "793([Symbol name=mul])" - "792([Symbol name=broadcast_in_dim])" -> "793([Symbol name=mul])" - "785([Symbol name=convert_element_type])" -> "793([Symbol name=mul])" - "786([Symbol name=mul])" - "785([Symbol name=convert_element_type])" -> "786([Symbol name=mul])" - "806([Symbol name=reciprocal])" - "805([Symbol name=add])" -> "806([Symbol name=reciprocal])" - "829([Symbol name=mul])" - "828([Symbol name=broadcast_in_dim])" -> "829([Symbol name=mul])" - "821([Symbol name=convert_element_type])" -> "829([Symbol name=mul])" - "822([Symbol name=mul])" - "821([Symbol name=convert_element_type])" -> "822([Symbol name=mul])" - "873([Symbol name=slice_prim])" - "840([Symbol name=reshape])" -> "873([Symbol name=slice_prim])" - "843([Symbol name=slice_prim])" - "840([Symbol name=reshape])" -> "843([Symbol name=slice_prim])" - "858([Symbol name=slice_prim])" - "841([Symbol name=reshape])" -> "858([Symbol name=slice_prim])" - "875([Symbol name=slice_prim])" - "841([Symbol name=reshape])" -> "875([Symbol name=slice_prim])" - "877([Symbol name=cudnn_sdpa_fwd])" - "874([Symbol name=cat])" -> "877([Symbol name=cudnn_sdpa_fwd])" - "876([Symbol name=cat])" -> "877([Symbol name=cudnn_sdpa_fwd])" - "842([Symbol name=reshape])" -> "877([Symbol name=cudnn_sdpa_fwd])" - "893([Symbol name=mul])" - "892([Symbol name=broadcast_in_dim])" -> "893([Symbol name=mul])" - "885([Symbol name=convert_element_type])" -> "893([Symbol name=mul])" - "886([Symbol name=mul])" - "885([Symbol name=convert_element_type])" -> "886([Symbol name=mul])" - "906([Symbol name=reciprocal])" - "905([Symbol name=add])" -> "906([Symbol name=reciprocal])" - "929([Symbol name=mul])" - "928([Symbol name=broadcast_in_dim])" -> "929([Symbol name=mul])" - "921([Symbol name=convert_element_type])" -> "929([Symbol name=mul])" - "922([Symbol name=mul])" - "921([Symbol name=convert_element_type])" -> "922([Symbol name=mul])" - "973([Symbol name=slice_prim])" - "940([Symbol name=reshape])" -> "973([Symbol name=slice_prim])" - "943([Symbol name=slice_prim])" - "940([Symbol name=reshape])" -> "943([Symbol name=slice_prim])" - "958([Symbol name=slice_prim])" - "941([Symbol name=reshape])" -> "958([Symbol name=slice_prim])" - "975([Symbol name=slice_prim])" - "941([Symbol name=reshape])" -> "975([Symbol name=slice_prim])" - "977([Symbol name=cudnn_sdpa_fwd])" - "976([Symbol name=cat])" -> "977([Symbol name=cudnn_sdpa_fwd])" - "942([Symbol name=reshape])" -> "977([Symbol name=cudnn_sdpa_fwd])" - "974([Symbol name=cat])" -> "977([Symbol name=cudnn_sdpa_fwd])" - "993([Symbol name=mul])" - "992([Symbol name=broadcast_in_dim])" -> "993([Symbol name=mul])" - "985([Symbol name=convert_element_type])" -> "993([Symbol name=mul])" - "986([Symbol name=mul])" - "985([Symbol name=convert_element_type])" -> "986([Symbol name=mul])" - "1006([Symbol name=reciprocal])" - "1005([Symbol name=add])" -> "1006([Symbol name=reciprocal])" - "1029([Symbol name=mul])" - "1028([Symbol name=broadcast_in_dim])" -> "1029([Symbol name=mul])" - "1021([Symbol name=convert_element_type])" -> "1029([Symbol name=mul])" - "1022([Symbol name=mul])" - "1021([Symbol name=convert_element_type])" -> "1022([Symbol name=mul])" - "1073([Symbol name=slice_prim])" - "1040([Symbol name=reshape])" -> "1073([Symbol name=slice_prim])" - "1043([Symbol name=slice_prim])" - "1040([Symbol name=reshape])" -> "1043([Symbol name=slice_prim])" - "1058([Symbol name=slice_prim])" - "1041([Symbol name=reshape])" -> "1058([Symbol name=slice_prim])" - "1075([Symbol name=slice_prim])" - "1041([Symbol name=reshape])" -> "1075([Symbol name=slice_prim])" - "1077([Symbol name=cudnn_sdpa_fwd])" - "1074([Symbol name=cat])" -> "1077([Symbol name=cudnn_sdpa_fwd])" - "1042([Symbol name=reshape])" -> "1077([Symbol name=cudnn_sdpa_fwd])" - "1076([Symbol name=cat])" -> "1077([Symbol name=cudnn_sdpa_fwd])" - "1093([Symbol name=mul])" - "1092([Symbol name=broadcast_in_dim])" -> "1093([Symbol name=mul])" - "1085([Symbol name=convert_element_type])" -> "1093([Symbol name=mul])" - "1086([Symbol name=mul])" - "1085([Symbol name=convert_element_type])" -> "1086([Symbol name=mul])" - "1106([Symbol name=reciprocal])" - "1105([Symbol name=add])" -> "1106([Symbol name=reciprocal])" - "1129([Symbol name=mul])" - "1128([Symbol name=broadcast_in_dim])" -> "1129([Symbol name=mul])" - "1121([Symbol name=convert_element_type])" -> "1129([Symbol name=mul])" - "1122([Symbol name=mul])" - "1121([Symbol name=convert_element_type])" -> "1122([Symbol name=mul])" - "1173([Symbol name=slice_prim])" - "1140([Symbol name=reshape])" -> "1173([Symbol name=slice_prim])" - "1143([Symbol name=slice_prim])" - "1140([Symbol name=reshape])" -> "1143([Symbol name=slice_prim])" - "1158([Symbol name=slice_prim])" - "1141([Symbol name=reshape])" -> "1158([Symbol name=slice_prim])" - "1175([Symbol name=slice_prim])" - "1141([Symbol name=reshape])" -> "1175([Symbol name=slice_prim])" - "1177([Symbol name=cudnn_sdpa_fwd])" - "1176([Symbol name=cat])" -> "1177([Symbol name=cudnn_sdpa_fwd])" - "1142([Symbol name=reshape])" -> "1177([Symbol name=cudnn_sdpa_fwd])" - "1174([Symbol name=cat])" -> "1177([Symbol name=cudnn_sdpa_fwd])" - "1193([Symbol name=mul])" - "1192([Symbol name=broadcast_in_dim])" -> "1193([Symbol name=mul])" - "1185([Symbol name=convert_element_type])" -> "1193([Symbol name=mul])" - "1186([Symbol name=mul])" - "1185([Symbol name=convert_element_type])" -> "1186([Symbol name=mul])" - "1206([Symbol name=reciprocal])" - "1205([Symbol name=add])" -> "1206([Symbol name=reciprocal])" - "1229([Symbol name=mul])" - "1228([Symbol name=broadcast_in_dim])" -> "1229([Symbol name=mul])" - "1221([Symbol name=convert_element_type])" -> "1229([Symbol name=mul])" - "1222([Symbol name=mul])" - "1221([Symbol name=convert_element_type])" -> "1222([Symbol name=mul])" - "1273([Symbol name=slice_prim])" - "1240([Symbol name=reshape])" -> "1273([Symbol name=slice_prim])" - "1243([Symbol name=slice_prim])" - "1240([Symbol name=reshape])" -> "1243([Symbol name=slice_prim])" - "1258([Symbol name=slice_prim])" - "1241([Symbol name=reshape])" -> "1258([Symbol name=slice_prim])" - "1275([Symbol name=slice_prim])" - "1241([Symbol name=reshape])" -> "1275([Symbol name=slice_prim])" - "1277([Symbol name=cudnn_sdpa_fwd])" - "1242([Symbol name=reshape])" -> "1277([Symbol name=cudnn_sdpa_fwd])" - "1274([Symbol name=cat])" -> "1277([Symbol name=cudnn_sdpa_fwd])" - "1276([Symbol name=cat])" -> "1277([Symbol name=cudnn_sdpa_fwd])" - "1293([Symbol name=mul])" - "1292([Symbol name=broadcast_in_dim])" -> "1293([Symbol name=mul])" - "1285([Symbol name=convert_element_type])" -> "1293([Symbol name=mul])" - "1286([Symbol name=mul])" - "1285([Symbol name=convert_element_type])" -> "1286([Symbol name=mul])" - "1306([Symbol name=reciprocal])" - "1305([Symbol name=add])" -> "1306([Symbol name=reciprocal])" - "1329([Symbol name=mul])" - "1328([Symbol name=broadcast_in_dim])" -> "1329([Symbol name=mul])" - "1321([Symbol name=convert_element_type])" -> "1329([Symbol name=mul])" - "1322([Symbol name=mul])" - "1321([Symbol name=convert_element_type])" -> "1322([Symbol name=mul])" - "1373([Symbol name=slice_prim])" - "1340([Symbol name=reshape])" -> "1373([Symbol name=slice_prim])" - "1343([Symbol name=slice_prim])" - "1340([Symbol name=reshape])" -> "1343([Symbol name=slice_prim])" - "1358([Symbol name=slice_prim])" - "1341([Symbol name=reshape])" -> "1358([Symbol name=slice_prim])" - "1375([Symbol name=slice_prim])" - "1341([Symbol name=reshape])" -> "1375([Symbol name=slice_prim])" - "1377([Symbol name=cudnn_sdpa_fwd])" - "1376([Symbol name=cat])" -> "1377([Symbol name=cudnn_sdpa_fwd])" - "1342([Symbol name=reshape])" -> "1377([Symbol name=cudnn_sdpa_fwd])" - "1374([Symbol name=cat])" -> "1377([Symbol name=cudnn_sdpa_fwd])" - "1393([Symbol name=mul])" - "1392([Symbol name=broadcast_in_dim])" -> "1393([Symbol name=mul])" - "1385([Symbol name=convert_element_type])" -> "1393([Symbol name=mul])" - "1386([Symbol name=mul])" - "1385([Symbol name=convert_element_type])" -> "1386([Symbol name=mul])" - "1406([Symbol name=reciprocal])" - "1405([Symbol name=add])" -> "1406([Symbol name=reciprocal])" - "1429([Symbol name=mul])" - "1428([Symbol name=broadcast_in_dim])" -> "1429([Symbol name=mul])" - "1421([Symbol name=convert_element_type])" -> "1429([Symbol name=mul])" - "1422([Symbol name=mul])" - "1421([Symbol name=convert_element_type])" -> "1422([Symbol name=mul])" - "1473([Symbol name=slice_prim])" - "1440([Symbol name=reshape])" -> "1473([Symbol name=slice_prim])" - "1443([Symbol name=slice_prim])" - "1440([Symbol name=reshape])" -> "1443([Symbol name=slice_prim])" - "1458([Symbol name=slice_prim])" - "1441([Symbol name=reshape])" -> "1458([Symbol name=slice_prim])" - "1475([Symbol name=slice_prim])" - "1441([Symbol name=reshape])" -> "1475([Symbol name=slice_prim])" - "1477([Symbol name=cudnn_sdpa_fwd])" - "1442([Symbol name=reshape])" -> "1477([Symbol name=cudnn_sdpa_fwd])" - "1474([Symbol name=cat])" -> "1477([Symbol name=cudnn_sdpa_fwd])" - "1476([Symbol name=cat])" -> "1477([Symbol name=cudnn_sdpa_fwd])" - "1493([Symbol name=mul])" - "1492([Symbol name=broadcast_in_dim])" -> "1493([Symbol name=mul])" - "1485([Symbol name=convert_element_type])" -> "1493([Symbol name=mul])" - "1486([Symbol name=mul])" - "1485([Symbol name=convert_element_type])" -> "1486([Symbol name=mul])" - "1506([Symbol name=reciprocal])" - "1505([Symbol name=add])" -> "1506([Symbol name=reciprocal])" - "1529([Symbol name=mul])" - "1528([Symbol name=broadcast_in_dim])" -> "1529([Symbol name=mul])" - "1521([Symbol name=convert_element_type])" -> "1529([Symbol name=mul])" - "1522([Symbol name=mul])" - "1521([Symbol name=convert_element_type])" -> "1522([Symbol name=mul])" - "1573([Symbol name=slice_prim])" - "1540([Symbol name=reshape])" -> "1573([Symbol name=slice_prim])" - "1543([Symbol name=slice_prim])" - "1540([Symbol name=reshape])" -> "1543([Symbol name=slice_prim])" - "1558([Symbol name=slice_prim])" - "1541([Symbol name=reshape])" -> "1558([Symbol name=slice_prim])" - "1575([Symbol name=slice_prim])" - "1541([Symbol name=reshape])" -> "1575([Symbol name=slice_prim])" - "1577([Symbol name=cudnn_sdpa_fwd])" - "1576([Symbol name=cat])" -> "1577([Symbol name=cudnn_sdpa_fwd])" - "1574([Symbol name=cat])" -> "1577([Symbol name=cudnn_sdpa_fwd])" - "1542([Symbol name=reshape])" -> "1577([Symbol name=cudnn_sdpa_fwd])" - "1593([Symbol name=mul])" - "1592([Symbol name=broadcast_in_dim])" -> "1593([Symbol name=mul])" - "1585([Symbol name=convert_element_type])" -> "1593([Symbol name=mul])" - "1586([Symbol name=mul])" - "1585([Symbol name=convert_element_type])" -> "1586([Symbol name=mul])" - "1606([Symbol name=reciprocal])" - "1605([Symbol name=add])" -> "1606([Symbol name=reciprocal])" - "1629([Symbol name=mul])" - "1628([Symbol name=broadcast_in_dim])" -> "1629([Symbol name=mul])" - "1621([Symbol name=convert_element_type])" -> "1629([Symbol name=mul])" - "1622([Symbol name=mul])" - "1621([Symbol name=convert_element_type])" -> "1622([Symbol name=mul])" - "1673([Symbol name=slice_prim])" - "1640([Symbol name=reshape])" -> "1673([Symbol name=slice_prim])" - "1643([Symbol name=slice_prim])" - "1640([Symbol name=reshape])" -> "1643([Symbol name=slice_prim])" - "1658([Symbol name=slice_prim])" - "1641([Symbol name=reshape])" -> "1658([Symbol name=slice_prim])" - "1675([Symbol name=slice_prim])" - "1641([Symbol name=reshape])" -> "1675([Symbol name=slice_prim])" - "1677([Symbol name=cudnn_sdpa_fwd])" - "1674([Symbol name=cat])" -> "1677([Symbol name=cudnn_sdpa_fwd])" - "1676([Symbol name=cat])" -> "1677([Symbol name=cudnn_sdpa_fwd])" - "1642([Symbol name=reshape])" -> "1677([Symbol name=cudnn_sdpa_fwd])" - "1693([Symbol name=mul])" - "1692([Symbol name=broadcast_in_dim])" -> "1693([Symbol name=mul])" - "1685([Symbol name=convert_element_type])" -> "1693([Symbol name=mul])" - "1686([Symbol name=mul])" - "1685([Symbol name=convert_element_type])" -> "1686([Symbol name=mul])" - "1706([Symbol name=reciprocal])" - "1705([Symbol name=add])" -> "1706([Symbol name=reciprocal])" - "1729([Symbol name=mul])" - "1728([Symbol name=broadcast_in_dim])" -> "1729([Symbol name=mul])" - "1721([Symbol name=convert_element_type])" -> "1729([Symbol name=mul])" - "1722([Symbol name=mul])" - "1721([Symbol name=convert_element_type])" -> "1722([Symbol name=mul])" - "126([Symbol name=add])" - "125([Symbol name=true_divide])" -> "126([Symbol name=add])" - "194([Symbol name=convert_element_type])" - "193([Symbol name=mul])" -> "194([Symbol name=convert_element_type])" - "187([Symbol name=sum])" - "186([Symbol name=mul])" -> "187([Symbol name=sum])" - "144([Symbol name=slice_prim])" - "143([Symbol name=slice_prim])" -> "144([Symbol name=slice_prim])" - "145([Symbol name=slice_prim])" - "143([Symbol name=slice_prim])" -> "145([Symbol name=slice_prim])" - "151([Symbol name=convert_element_type])" - "143([Symbol name=slice_prim])" -> "151([Symbol name=convert_element_type])" - "160([Symbol name=slice_prim])" - "158([Symbol name=slice_prim])" -> "160([Symbol name=slice_prim])" - "166([Symbol name=convert_element_type])" - "158([Symbol name=slice_prim])" -> "166([Symbol name=convert_element_type])" - "159([Symbol name=slice_prim])" - "158([Symbol name=slice_prim])" -> "159([Symbol name=slice_prim])" - "178([Symbol name=transpose])" - "177([Symbol name=cudnn_sdpa_fwd])" -> "178([Symbol name=transpose])" - "207([Symbol name=convert_element_type])" - "206([Symbol name=reciprocal])" -> "207([Symbol name=convert_element_type])" - "230([Symbol name=convert_element_type])" - "229([Symbol name=mul])" -> "230([Symbol name=convert_element_type])" - "223([Symbol name=sum])" - "222([Symbol name=mul])" -> "223([Symbol name=sum])" - "251([Symbol name=convert_element_type])" - "243([Symbol name=slice_prim])" -> "251([Symbol name=convert_element_type])" - "244([Symbol name=slice_prim])" - "243([Symbol name=slice_prim])" -> "244([Symbol name=slice_prim])" - "245([Symbol name=slice_prim])" - "243([Symbol name=slice_prim])" -> "245([Symbol name=slice_prim])" - "266([Symbol name=convert_element_type])" - "258([Symbol name=slice_prim])" -> "266([Symbol name=convert_element_type])" - "259([Symbol name=slice_prim])" - "258([Symbol name=slice_prim])" -> "259([Symbol name=slice_prim])" - "260([Symbol name=slice_prim])" - "258([Symbol name=slice_prim])" -> "260([Symbol name=slice_prim])" - "278([Symbol name=transpose])" - "277([Symbol name=cudnn_sdpa_fwd])" -> "278([Symbol name=transpose])" - "294([Symbol name=convert_element_type])" - "293([Symbol name=mul])" -> "294([Symbol name=convert_element_type])" - "287([Symbol name=sum])" - "286([Symbol name=mul])" -> "287([Symbol name=sum])" - "307([Symbol name=convert_element_type])" - "306([Symbol name=reciprocal])" -> "307([Symbol name=convert_element_type])" - "330([Symbol name=convert_element_type])" - "329([Symbol name=mul])" -> "330([Symbol name=convert_element_type])" - "323([Symbol name=sum])" - "322([Symbol name=mul])" -> "323([Symbol name=sum])" - "344([Symbol name=slice_prim])" - "343([Symbol name=slice_prim])" -> "344([Symbol name=slice_prim])" - "345([Symbol name=slice_prim])" - "343([Symbol name=slice_prim])" -> "345([Symbol name=slice_prim])" - "351([Symbol name=convert_element_type])" - "343([Symbol name=slice_prim])" -> "351([Symbol name=convert_element_type])" - "360([Symbol name=slice_prim])" - "358([Symbol name=slice_prim])" -> "360([Symbol name=slice_prim])" - "366([Symbol name=convert_element_type])" - "358([Symbol name=slice_prim])" -> "366([Symbol name=convert_element_type])" - "359([Symbol name=slice_prim])" - "358([Symbol name=slice_prim])" -> "359([Symbol name=slice_prim])" - "378([Symbol name=transpose])" - "377([Symbol name=cudnn_sdpa_fwd])" -> "378([Symbol name=transpose])" - "394([Symbol name=convert_element_type])" - "393([Symbol name=mul])" -> "394([Symbol name=convert_element_type])" - "387([Symbol name=sum])" - "386([Symbol name=mul])" -> "387([Symbol name=sum])" - "407([Symbol name=convert_element_type])" - "406([Symbol name=reciprocal])" -> "407([Symbol name=convert_element_type])" - "430([Symbol name=convert_element_type])" - "429([Symbol name=mul])" -> "430([Symbol name=convert_element_type])" - "423([Symbol name=sum])" - "422([Symbol name=mul])" -> "423([Symbol name=sum])" - "451([Symbol name=convert_element_type])" - "443([Symbol name=slice_prim])" -> "451([Symbol name=convert_element_type])" - "444([Symbol name=slice_prim])" - "443([Symbol name=slice_prim])" -> "444([Symbol name=slice_prim])" - "445([Symbol name=slice_prim])" - "443([Symbol name=slice_prim])" -> "445([Symbol name=slice_prim])" - "466([Symbol name=convert_element_type])" - "458([Symbol name=slice_prim])" -> "466([Symbol name=convert_element_type])" - "459([Symbol name=slice_prim])" - "458([Symbol name=slice_prim])" -> "459([Symbol name=slice_prim])" - "460([Symbol name=slice_prim])" - "458([Symbol name=slice_prim])" -> "460([Symbol name=slice_prim])" - "478([Symbol name=transpose])" - "477([Symbol name=cudnn_sdpa_fwd])" -> "478([Symbol name=transpose])" - "494([Symbol name=convert_element_type])" - "493([Symbol name=mul])" -> "494([Symbol name=convert_element_type])" - "487([Symbol name=sum])" - "486([Symbol name=mul])" -> "487([Symbol name=sum])" - "507([Symbol name=convert_element_type])" - "506([Symbol name=reciprocal])" -> "507([Symbol name=convert_element_type])" - "530([Symbol name=convert_element_type])" - "529([Symbol name=mul])" -> "530([Symbol name=convert_element_type])" - "523([Symbol name=sum])" - "522([Symbol name=mul])" -> "523([Symbol name=sum])" - "544([Symbol name=slice_prim])" - "543([Symbol name=slice_prim])" -> "544([Symbol name=slice_prim])" - "545([Symbol name=slice_prim])" - "543([Symbol name=slice_prim])" -> "545([Symbol name=slice_prim])" - "551([Symbol name=convert_element_type])" - "543([Symbol name=slice_prim])" -> "551([Symbol name=convert_element_type])" - "560([Symbol name=slice_prim])" - "558([Symbol name=slice_prim])" -> "560([Symbol name=slice_prim])" - "566([Symbol name=convert_element_type])" - "558([Symbol name=slice_prim])" -> "566([Symbol name=convert_element_type])" - "559([Symbol name=slice_prim])" - "558([Symbol name=slice_prim])" -> "559([Symbol name=slice_prim])" - "578([Symbol name=transpose])" - "577([Symbol name=cudnn_sdpa_fwd])" -> "578([Symbol name=transpose])" - "594([Symbol name=convert_element_type])" - "593([Symbol name=mul])" -> "594([Symbol name=convert_element_type])" - "587([Symbol name=sum])" - "586([Symbol name=mul])" -> "587([Symbol name=sum])" - "607([Symbol name=convert_element_type])" - "606([Symbol name=reciprocal])" -> "607([Symbol name=convert_element_type])" - "630([Symbol name=convert_element_type])" - "629([Symbol name=mul])" -> "630([Symbol name=convert_element_type])" - "623([Symbol name=sum])" - "622([Symbol name=mul])" -> "623([Symbol name=sum])" - "651([Symbol name=convert_element_type])" - "643([Symbol name=slice_prim])" -> "651([Symbol name=convert_element_type])" - "644([Symbol name=slice_prim])" - "643([Symbol name=slice_prim])" -> "644([Symbol name=slice_prim])" - "645([Symbol name=slice_prim])" - "643([Symbol name=slice_prim])" -> "645([Symbol name=slice_prim])" - "666([Symbol name=convert_element_type])" - "658([Symbol name=slice_prim])" -> "666([Symbol name=convert_element_type])" - "659([Symbol name=slice_prim])" - "658([Symbol name=slice_prim])" -> "659([Symbol name=slice_prim])" - "660([Symbol name=slice_prim])" - "658([Symbol name=slice_prim])" -> "660([Symbol name=slice_prim])" - "678([Symbol name=transpose])" - "677([Symbol name=cudnn_sdpa_fwd])" -> "678([Symbol name=transpose])" - "694([Symbol name=convert_element_type])" - "693([Symbol name=mul])" -> "694([Symbol name=convert_element_type])" - "687([Symbol name=sum])" - "686([Symbol name=mul])" -> "687([Symbol name=sum])" - "707([Symbol name=convert_element_type])" - "706([Symbol name=reciprocal])" -> "707([Symbol name=convert_element_type])" - "730([Symbol name=convert_element_type])" - "729([Symbol name=mul])" -> "730([Symbol name=convert_element_type])" - "723([Symbol name=sum])" - "722([Symbol name=mul])" -> "723([Symbol name=sum])" - "744([Symbol name=slice_prim])" - "743([Symbol name=slice_prim])" -> "744([Symbol name=slice_prim])" - "745([Symbol name=slice_prim])" - "743([Symbol name=slice_prim])" -> "745([Symbol name=slice_prim])" - "751([Symbol name=convert_element_type])" - "743([Symbol name=slice_prim])" -> "751([Symbol name=convert_element_type])" - "760([Symbol name=slice_prim])" - "758([Symbol name=slice_prim])" -> "760([Symbol name=slice_prim])" - "766([Symbol name=convert_element_type])" - "758([Symbol name=slice_prim])" -> "766([Symbol name=convert_element_type])" - "759([Symbol name=slice_prim])" - "758([Symbol name=slice_prim])" -> "759([Symbol name=slice_prim])" - "778([Symbol name=transpose])" - "777([Symbol name=cudnn_sdpa_fwd])" -> "778([Symbol name=transpose])" - "794([Symbol name=convert_element_type])" - "793([Symbol name=mul])" -> "794([Symbol name=convert_element_type])" - "787([Symbol name=sum])" - "786([Symbol name=mul])" -> "787([Symbol name=sum])" - "807([Symbol name=convert_element_type])" - "806([Symbol name=reciprocal])" -> "807([Symbol name=convert_element_type])" - "830([Symbol name=convert_element_type])" - "829([Symbol name=mul])" -> "830([Symbol name=convert_element_type])" - "823([Symbol name=sum])" - "822([Symbol name=mul])" -> "823([Symbol name=sum])" - "851([Symbol name=convert_element_type])" - "843([Symbol name=slice_prim])" -> "851([Symbol name=convert_element_type])" - "844([Symbol name=slice_prim])" - "843([Symbol name=slice_prim])" -> "844([Symbol name=slice_prim])" - "845([Symbol name=slice_prim])" - "843([Symbol name=slice_prim])" -> "845([Symbol name=slice_prim])" - "866([Symbol name=convert_element_type])" - "858([Symbol name=slice_prim])" -> "866([Symbol name=convert_element_type])" - "859([Symbol name=slice_prim])" - "858([Symbol name=slice_prim])" -> "859([Symbol name=slice_prim])" - "860([Symbol name=slice_prim])" - "858([Symbol name=slice_prim])" -> "860([Symbol name=slice_prim])" - "878([Symbol name=transpose])" - "877([Symbol name=cudnn_sdpa_fwd])" -> "878([Symbol name=transpose])" - "894([Symbol name=convert_element_type])" - "893([Symbol name=mul])" -> "894([Symbol name=convert_element_type])" - "887([Symbol name=sum])" - "886([Symbol name=mul])" -> "887([Symbol name=sum])" - "907([Symbol name=convert_element_type])" - "906([Symbol name=reciprocal])" -> "907([Symbol name=convert_element_type])" - "930([Symbol name=convert_element_type])" - "929([Symbol name=mul])" -> "930([Symbol name=convert_element_type])" - "923([Symbol name=sum])" - "922([Symbol name=mul])" -> "923([Symbol name=sum])" - "944([Symbol name=slice_prim])" - "943([Symbol name=slice_prim])" -> "944([Symbol name=slice_prim])" - "945([Symbol name=slice_prim])" - "943([Symbol name=slice_prim])" -> "945([Symbol name=slice_prim])" - "951([Symbol name=convert_element_type])" - "943([Symbol name=slice_prim])" -> "951([Symbol name=convert_element_type])" - "960([Symbol name=slice_prim])" - "958([Symbol name=slice_prim])" -> "960([Symbol name=slice_prim])" - "966([Symbol name=convert_element_type])" - "958([Symbol name=slice_prim])" -> "966([Symbol name=convert_element_type])" - "959([Symbol name=slice_prim])" - "958([Symbol name=slice_prim])" -> "959([Symbol name=slice_prim])" - "978([Symbol name=transpose])" - "977([Symbol name=cudnn_sdpa_fwd])" -> "978([Symbol name=transpose])" - "994([Symbol name=convert_element_type])" - "993([Symbol name=mul])" -> "994([Symbol name=convert_element_type])" - "987([Symbol name=sum])" - "986([Symbol name=mul])" -> "987([Symbol name=sum])" - "1007([Symbol name=convert_element_type])" - "1006([Symbol name=reciprocal])" -> "1007([Symbol name=convert_element_type])" - "1030([Symbol name=convert_element_type])" - "1029([Symbol name=mul])" -> "1030([Symbol name=convert_element_type])" - "1023([Symbol name=sum])" - "1022([Symbol name=mul])" -> "1023([Symbol name=sum])" - "1051([Symbol name=convert_element_type])" - "1043([Symbol name=slice_prim])" -> "1051([Symbol name=convert_element_type])" - "1044([Symbol name=slice_prim])" - "1043([Symbol name=slice_prim])" -> "1044([Symbol name=slice_prim])" - "1045([Symbol name=slice_prim])" - "1043([Symbol name=slice_prim])" -> "1045([Symbol name=slice_prim])" - "1066([Symbol name=convert_element_type])" - "1058([Symbol name=slice_prim])" -> "1066([Symbol name=convert_element_type])" - "1059([Symbol name=slice_prim])" - "1058([Symbol name=slice_prim])" -> "1059([Symbol name=slice_prim])" - "1060([Symbol name=slice_prim])" - "1058([Symbol name=slice_prim])" -> "1060([Symbol name=slice_prim])" - "1078([Symbol name=transpose])" - "1077([Symbol name=cudnn_sdpa_fwd])" -> "1078([Symbol name=transpose])" - "1094([Symbol name=convert_element_type])" - "1093([Symbol name=mul])" -> "1094([Symbol name=convert_element_type])" - "1087([Symbol name=sum])" - "1086([Symbol name=mul])" -> "1087([Symbol name=sum])" - "1107([Symbol name=convert_element_type])" - "1106([Symbol name=reciprocal])" -> "1107([Symbol name=convert_element_type])" - "1130([Symbol name=convert_element_type])" - "1129([Symbol name=mul])" -> "1130([Symbol name=convert_element_type])" - "1123([Symbol name=sum])" - "1122([Symbol name=mul])" -> "1123([Symbol name=sum])" - "1144([Symbol name=slice_prim])" - "1143([Symbol name=slice_prim])" -> "1144([Symbol name=slice_prim])" - "1145([Symbol name=slice_prim])" - "1143([Symbol name=slice_prim])" -> "1145([Symbol name=slice_prim])" - "1151([Symbol name=convert_element_type])" - "1143([Symbol name=slice_prim])" -> "1151([Symbol name=convert_element_type])" - "1160([Symbol name=slice_prim])" - "1158([Symbol name=slice_prim])" -> "1160([Symbol name=slice_prim])" - "1166([Symbol name=convert_element_type])" - "1158([Symbol name=slice_prim])" -> "1166([Symbol name=convert_element_type])" - "1159([Symbol name=slice_prim])" - "1158([Symbol name=slice_prim])" -> "1159([Symbol name=slice_prim])" - "1178([Symbol name=transpose])" - "1177([Symbol name=cudnn_sdpa_fwd])" -> "1178([Symbol name=transpose])" - "1194([Symbol name=convert_element_type])" - "1193([Symbol name=mul])" -> "1194([Symbol name=convert_element_type])" - "1187([Symbol name=sum])" - "1186([Symbol name=mul])" -> "1187([Symbol name=sum])" - "1207([Symbol name=convert_element_type])" - "1206([Symbol name=reciprocal])" -> "1207([Symbol name=convert_element_type])" - "1230([Symbol name=convert_element_type])" - "1229([Symbol name=mul])" -> "1230([Symbol name=convert_element_type])" - "1223([Symbol name=sum])" - "1222([Symbol name=mul])" -> "1223([Symbol name=sum])" - "1251([Symbol name=convert_element_type])" - "1243([Symbol name=slice_prim])" -> "1251([Symbol name=convert_element_type])" - "1244([Symbol name=slice_prim])" - "1243([Symbol name=slice_prim])" -> "1244([Symbol name=slice_prim])" - "1245([Symbol name=slice_prim])" - "1243([Symbol name=slice_prim])" -> "1245([Symbol name=slice_prim])" - "1266([Symbol name=convert_element_type])" - "1258([Symbol name=slice_prim])" -> "1266([Symbol name=convert_element_type])" - "1259([Symbol name=slice_prim])" - "1258([Symbol name=slice_prim])" -> "1259([Symbol name=slice_prim])" - "1260([Symbol name=slice_prim])" - "1258([Symbol name=slice_prim])" -> "1260([Symbol name=slice_prim])" - "1278([Symbol name=transpose])" - "1277([Symbol name=cudnn_sdpa_fwd])" -> "1278([Symbol name=transpose])" - "1294([Symbol name=convert_element_type])" - "1293([Symbol name=mul])" -> "1294([Symbol name=convert_element_type])" - "1287([Symbol name=sum])" - "1286([Symbol name=mul])" -> "1287([Symbol name=sum])" - "1307([Symbol name=convert_element_type])" - "1306([Symbol name=reciprocal])" -> "1307([Symbol name=convert_element_type])" - "1330([Symbol name=convert_element_type])" - "1329([Symbol name=mul])" -> "1330([Symbol name=convert_element_type])" - "1323([Symbol name=sum])" - "1322([Symbol name=mul])" -> "1323([Symbol name=sum])" - "1344([Symbol name=slice_prim])" - "1343([Symbol name=slice_prim])" -> "1344([Symbol name=slice_prim])" - "1345([Symbol name=slice_prim])" - "1343([Symbol name=slice_prim])" -> "1345([Symbol name=slice_prim])" - "1351([Symbol name=convert_element_type])" - "1343([Symbol name=slice_prim])" -> "1351([Symbol name=convert_element_type])" - "1360([Symbol name=slice_prim])" - "1358([Symbol name=slice_prim])" -> "1360([Symbol name=slice_prim])" - "1366([Symbol name=convert_element_type])" - "1358([Symbol name=slice_prim])" -> "1366([Symbol name=convert_element_type])" - "1359([Symbol name=slice_prim])" - "1358([Symbol name=slice_prim])" -> "1359([Symbol name=slice_prim])" - "1378([Symbol name=transpose])" - "1377([Symbol name=cudnn_sdpa_fwd])" -> "1378([Symbol name=transpose])" - "1394([Symbol name=convert_element_type])" - "1393([Symbol name=mul])" -> "1394([Symbol name=convert_element_type])" - "1387([Symbol name=sum])" - "1386([Symbol name=mul])" -> "1387([Symbol name=sum])" - "1407([Symbol name=convert_element_type])" - "1406([Symbol name=reciprocal])" -> "1407([Symbol name=convert_element_type])" - "1430([Symbol name=convert_element_type])" - "1429([Symbol name=mul])" -> "1430([Symbol name=convert_element_type])" - "1423([Symbol name=sum])" - "1422([Symbol name=mul])" -> "1423([Symbol name=sum])" - "1451([Symbol name=convert_element_type])" - "1443([Symbol name=slice_prim])" -> "1451([Symbol name=convert_element_type])" - "1444([Symbol name=slice_prim])" - "1443([Symbol name=slice_prim])" -> "1444([Symbol name=slice_prim])" - "1445([Symbol name=slice_prim])" - "1443([Symbol name=slice_prim])" -> "1445([Symbol name=slice_prim])" - "1466([Symbol name=convert_element_type])" - "1458([Symbol name=slice_prim])" -> "1466([Symbol name=convert_element_type])" - "1459([Symbol name=slice_prim])" - "1458([Symbol name=slice_prim])" -> "1459([Symbol name=slice_prim])" - "1460([Symbol name=slice_prim])" - "1458([Symbol name=slice_prim])" -> "1460([Symbol name=slice_prim])" - "1478([Symbol name=transpose])" - "1477([Symbol name=cudnn_sdpa_fwd])" -> "1478([Symbol name=transpose])" - "1494([Symbol name=convert_element_type])" - "1493([Symbol name=mul])" -> "1494([Symbol name=convert_element_type])" - "1487([Symbol name=sum])" - "1486([Symbol name=mul])" -> "1487([Symbol name=sum])" - "1507([Symbol name=convert_element_type])" - "1506([Symbol name=reciprocal])" -> "1507([Symbol name=convert_element_type])" - "1530([Symbol name=convert_element_type])" - "1529([Symbol name=mul])" -> "1530([Symbol name=convert_element_type])" - "1523([Symbol name=sum])" - "1522([Symbol name=mul])" -> "1523([Symbol name=sum])" - "1544([Symbol name=slice_prim])" - "1543([Symbol name=slice_prim])" -> "1544([Symbol name=slice_prim])" - "1545([Symbol name=slice_prim])" - "1543([Symbol name=slice_prim])" -> "1545([Symbol name=slice_prim])" - "1551([Symbol name=convert_element_type])" - "1543([Symbol name=slice_prim])" -> "1551([Symbol name=convert_element_type])" - "1560([Symbol name=slice_prim])" - "1558([Symbol name=slice_prim])" -> "1560([Symbol name=slice_prim])" - "1566([Symbol name=convert_element_type])" - "1558([Symbol name=slice_prim])" -> "1566([Symbol name=convert_element_type])" - "1559([Symbol name=slice_prim])" - "1558([Symbol name=slice_prim])" -> "1559([Symbol name=slice_prim])" - "1578([Symbol name=transpose])" - "1577([Symbol name=cudnn_sdpa_fwd])" -> "1578([Symbol name=transpose])" - "1594([Symbol name=convert_element_type])" - "1593([Symbol name=mul])" -> "1594([Symbol name=convert_element_type])" - "1587([Symbol name=sum])" - "1586([Symbol name=mul])" -> "1587([Symbol name=sum])" - "1607([Symbol name=convert_element_type])" - "1606([Symbol name=reciprocal])" -> "1607([Symbol name=convert_element_type])" - "1630([Symbol name=convert_element_type])" - "1629([Symbol name=mul])" -> "1630([Symbol name=convert_element_type])" - "1623([Symbol name=sum])" - "1622([Symbol name=mul])" -> "1623([Symbol name=sum])" - "1651([Symbol name=convert_element_type])" - "1643([Symbol name=slice_prim])" -> "1651([Symbol name=convert_element_type])" - "1644([Symbol name=slice_prim])" - "1643([Symbol name=slice_prim])" -> "1644([Symbol name=slice_prim])" - "1645([Symbol name=slice_prim])" - "1643([Symbol name=slice_prim])" -> "1645([Symbol name=slice_prim])" - "1666([Symbol name=convert_element_type])" - "1658([Symbol name=slice_prim])" -> "1666([Symbol name=convert_element_type])" - "1659([Symbol name=slice_prim])" - "1658([Symbol name=slice_prim])" -> "1659([Symbol name=slice_prim])" - "1660([Symbol name=slice_prim])" - "1658([Symbol name=slice_prim])" -> "1660([Symbol name=slice_prim])" - "1678([Symbol name=transpose])" - "1677([Symbol name=cudnn_sdpa_fwd])" -> "1678([Symbol name=transpose])" - "1694([Symbol name=convert_element_type])" - "1693([Symbol name=mul])" -> "1694([Symbol name=convert_element_type])" - "1687([Symbol name=sum])" - "1686([Symbol name=mul])" -> "1687([Symbol name=sum])" - "1707([Symbol name=convert_element_type])" - "1706([Symbol name=reciprocal])" -> "1707([Symbol name=convert_element_type])" - "1730([Symbol name=convert_element_type])" - "1729([Symbol name=mul])" -> "1730([Symbol name=convert_element_type])" - "1723([Symbol name=sum])" - "1722([Symbol name=mul])" -> "1723([Symbol name=sum])" - "127([Symbol name=rsqrt])" - "126([Symbol name=add])" -> "127([Symbol name=rsqrt])" - "196([Symbol name=convert_element_type])" - "194([Symbol name=convert_element_type])" -> "196([Symbol name=convert_element_type])" - "188([Symbol name=broadcast_in_dim])" - "187([Symbol name=sum])" -> "188([Symbol name=broadcast_in_dim])" - "149([Symbol name=cat])" - "144([Symbol name=slice_prim])" -> "149([Symbol name=cat])" - "148([Symbol name=convert_element_type])" -> "149([Symbol name=cat])" - "146([Symbol name=convert_element_type])" - "145([Symbol name=slice_prim])" -> "146([Symbol name=convert_element_type])" - "161([Symbol name=convert_element_type])" - "160([Symbol name=slice_prim])" -> "161([Symbol name=convert_element_type])" - "164([Symbol name=cat])" - "163([Symbol name=convert_element_type])" -> "164([Symbol name=cat])" - "159([Symbol name=slice_prim])" -> "164([Symbol name=cat])" - "179([Symbol name=reshape])" - "178([Symbol name=transpose])" -> "179([Symbol name=reshape])" - "209([Symbol name=convert_element_type])" - "207([Symbol name=convert_element_type])" -> "209([Symbol name=convert_element_type])" - "232([Symbol name=convert_element_type])" - "230([Symbol name=convert_element_type])" -> "232([Symbol name=convert_element_type])" - "224([Symbol name=broadcast_in_dim])" - "223([Symbol name=sum])" -> "224([Symbol name=broadcast_in_dim])" - "249([Symbol name=cat])" - "248([Symbol name=convert_element_type])" -> "249([Symbol name=cat])" - "244([Symbol name=slice_prim])" -> "249([Symbol name=cat])" - "246([Symbol name=convert_element_type])" - "245([Symbol name=slice_prim])" -> "246([Symbol name=convert_element_type])" - "264([Symbol name=cat])" - "259([Symbol name=slice_prim])" -> "264([Symbol name=cat])" - "263([Symbol name=convert_element_type])" -> "264([Symbol name=cat])" - "261([Symbol name=convert_element_type])" - "260([Symbol name=slice_prim])" -> "261([Symbol name=convert_element_type])" - "279([Symbol name=reshape])" - "278([Symbol name=transpose])" -> "279([Symbol name=reshape])" - "296([Symbol name=convert_element_type])" - "294([Symbol name=convert_element_type])" -> "296([Symbol name=convert_element_type])" - "288([Symbol name=broadcast_in_dim])" - "287([Symbol name=sum])" -> "288([Symbol name=broadcast_in_dim])" - "309([Symbol name=convert_element_type])" - "307([Symbol name=convert_element_type])" -> "309([Symbol name=convert_element_type])" - "332([Symbol name=convert_element_type])" - "330([Symbol name=convert_element_type])" -> "332([Symbol name=convert_element_type])" - "324([Symbol name=broadcast_in_dim])" - "323([Symbol name=sum])" -> "324([Symbol name=broadcast_in_dim])" - "349([Symbol name=cat])" - "344([Symbol name=slice_prim])" -> "349([Symbol name=cat])" - "348([Symbol name=convert_element_type])" -> "349([Symbol name=cat])" - "346([Symbol name=convert_element_type])" - "345([Symbol name=slice_prim])" -> "346([Symbol name=convert_element_type])" - "361([Symbol name=convert_element_type])" - "360([Symbol name=slice_prim])" -> "361([Symbol name=convert_element_type])" - "364([Symbol name=cat])" - "363([Symbol name=convert_element_type])" -> "364([Symbol name=cat])" - "359([Symbol name=slice_prim])" -> "364([Symbol name=cat])" - "379([Symbol name=reshape])" - "378([Symbol name=transpose])" -> "379([Symbol name=reshape])" - "396([Symbol name=convert_element_type])" - "394([Symbol name=convert_element_type])" -> "396([Symbol name=convert_element_type])" - "388([Symbol name=broadcast_in_dim])" - "387([Symbol name=sum])" -> "388([Symbol name=broadcast_in_dim])" - "409([Symbol name=convert_element_type])" - "407([Symbol name=convert_element_type])" -> "409([Symbol name=convert_element_type])" - "432([Symbol name=convert_element_type])" - "430([Symbol name=convert_element_type])" -> "432([Symbol name=convert_element_type])" - "424([Symbol name=broadcast_in_dim])" - "423([Symbol name=sum])" -> "424([Symbol name=broadcast_in_dim])" - "449([Symbol name=cat])" - "448([Symbol name=convert_element_type])" -> "449([Symbol name=cat])" - "444([Symbol name=slice_prim])" -> "449([Symbol name=cat])" - "446([Symbol name=convert_element_type])" - "445([Symbol name=slice_prim])" -> "446([Symbol name=convert_element_type])" - "464([Symbol name=cat])" - "459([Symbol name=slice_prim])" -> "464([Symbol name=cat])" - "463([Symbol name=convert_element_type])" -> "464([Symbol name=cat])" - "461([Symbol name=convert_element_type])" - "460([Symbol name=slice_prim])" -> "461([Symbol name=convert_element_type])" - "479([Symbol name=reshape])" - "478([Symbol name=transpose])" -> "479([Symbol name=reshape])" - "496([Symbol name=convert_element_type])" - "494([Symbol name=convert_element_type])" -> "496([Symbol name=convert_element_type])" - "488([Symbol name=broadcast_in_dim])" - "487([Symbol name=sum])" -> "488([Symbol name=broadcast_in_dim])" - "509([Symbol name=convert_element_type])" - "507([Symbol name=convert_element_type])" -> "509([Symbol name=convert_element_type])" - "532([Symbol name=convert_element_type])" - "530([Symbol name=convert_element_type])" -> "532([Symbol name=convert_element_type])" - "524([Symbol name=broadcast_in_dim])" - "523([Symbol name=sum])" -> "524([Symbol name=broadcast_in_dim])" - "549([Symbol name=cat])" - "544([Symbol name=slice_prim])" -> "549([Symbol name=cat])" - "548([Symbol name=convert_element_type])" -> "549([Symbol name=cat])" - "546([Symbol name=convert_element_type])" - "545([Symbol name=slice_prim])" -> "546([Symbol name=convert_element_type])" - "561([Symbol name=convert_element_type])" - "560([Symbol name=slice_prim])" -> "561([Symbol name=convert_element_type])" - "564([Symbol name=cat])" - "563([Symbol name=convert_element_type])" -> "564([Symbol name=cat])" - "559([Symbol name=slice_prim])" -> "564([Symbol name=cat])" - "579([Symbol name=reshape])" - "578([Symbol name=transpose])" -> "579([Symbol name=reshape])" - "596([Symbol name=convert_element_type])" - "594([Symbol name=convert_element_type])" -> "596([Symbol name=convert_element_type])" - "588([Symbol name=broadcast_in_dim])" - "587([Symbol name=sum])" -> "588([Symbol name=broadcast_in_dim])" - "609([Symbol name=convert_element_type])" - "607([Symbol name=convert_element_type])" -> "609([Symbol name=convert_element_type])" - "632([Symbol name=convert_element_type])" - "630([Symbol name=convert_element_type])" -> "632([Symbol name=convert_element_type])" - "624([Symbol name=broadcast_in_dim])" - "623([Symbol name=sum])" -> "624([Symbol name=broadcast_in_dim])" - "649([Symbol name=cat])" - "648([Symbol name=convert_element_type])" -> "649([Symbol name=cat])" - "644([Symbol name=slice_prim])" -> "649([Symbol name=cat])" - "646([Symbol name=convert_element_type])" - "645([Symbol name=slice_prim])" -> "646([Symbol name=convert_element_type])" - "664([Symbol name=cat])" - "659([Symbol name=slice_prim])" -> "664([Symbol name=cat])" - "663([Symbol name=convert_element_type])" -> "664([Symbol name=cat])" - "661([Symbol name=convert_element_type])" - "660([Symbol name=slice_prim])" -> "661([Symbol name=convert_element_type])" - "679([Symbol name=reshape])" - "678([Symbol name=transpose])" -> "679([Symbol name=reshape])" - "696([Symbol name=convert_element_type])" - "694([Symbol name=convert_element_type])" -> "696([Symbol name=convert_element_type])" - "688([Symbol name=broadcast_in_dim])" - "687([Symbol name=sum])" -> "688([Symbol name=broadcast_in_dim])" - "709([Symbol name=convert_element_type])" - "707([Symbol name=convert_element_type])" -> "709([Symbol name=convert_element_type])" - "732([Symbol name=convert_element_type])" - "730([Symbol name=convert_element_type])" -> "732([Symbol name=convert_element_type])" - "724([Symbol name=broadcast_in_dim])" - "723([Symbol name=sum])" -> "724([Symbol name=broadcast_in_dim])" - "749([Symbol name=cat])" - "744([Symbol name=slice_prim])" -> "749([Symbol name=cat])" - "748([Symbol name=convert_element_type])" -> "749([Symbol name=cat])" - "746([Symbol name=convert_element_type])" - "745([Symbol name=slice_prim])" -> "746([Symbol name=convert_element_type])" - "761([Symbol name=convert_element_type])" - "760([Symbol name=slice_prim])" -> "761([Symbol name=convert_element_type])" - "764([Symbol name=cat])" - "763([Symbol name=convert_element_type])" -> "764([Symbol name=cat])" - "759([Symbol name=slice_prim])" -> "764([Symbol name=cat])" - "779([Symbol name=reshape])" - "778([Symbol name=transpose])" -> "779([Symbol name=reshape])" - "796([Symbol name=convert_element_type])" - "794([Symbol name=convert_element_type])" -> "796([Symbol name=convert_element_type])" - "788([Symbol name=broadcast_in_dim])" - "787([Symbol name=sum])" -> "788([Symbol name=broadcast_in_dim])" - "809([Symbol name=convert_element_type])" - "807([Symbol name=convert_element_type])" -> "809([Symbol name=convert_element_type])" - "832([Symbol name=convert_element_type])" - "830([Symbol name=convert_element_type])" -> "832([Symbol name=convert_element_type])" - "824([Symbol name=broadcast_in_dim])" - "823([Symbol name=sum])" -> "824([Symbol name=broadcast_in_dim])" - "849([Symbol name=cat])" - "848([Symbol name=convert_element_type])" -> "849([Symbol name=cat])" - "844([Symbol name=slice_prim])" -> "849([Symbol name=cat])" - "846([Symbol name=convert_element_type])" - "845([Symbol name=slice_prim])" -> "846([Symbol name=convert_element_type])" - "864([Symbol name=cat])" - "859([Symbol name=slice_prim])" -> "864([Symbol name=cat])" - "863([Symbol name=convert_element_type])" -> "864([Symbol name=cat])" - "861([Symbol name=convert_element_type])" - "860([Symbol name=slice_prim])" -> "861([Symbol name=convert_element_type])" - "879([Symbol name=reshape])" - "878([Symbol name=transpose])" -> "879([Symbol name=reshape])" - "896([Symbol name=convert_element_type])" - "894([Symbol name=convert_element_type])" -> "896([Symbol name=convert_element_type])" - "888([Symbol name=broadcast_in_dim])" - "887([Symbol name=sum])" -> "888([Symbol name=broadcast_in_dim])" - "909([Symbol name=convert_element_type])" - "907([Symbol name=convert_element_type])" -> "909([Symbol name=convert_element_type])" - "932([Symbol name=convert_element_type])" - "930([Symbol name=convert_element_type])" -> "932([Symbol name=convert_element_type])" - "924([Symbol name=broadcast_in_dim])" - "923([Symbol name=sum])" -> "924([Symbol name=broadcast_in_dim])" - "949([Symbol name=cat])" - "944([Symbol name=slice_prim])" -> "949([Symbol name=cat])" - "948([Symbol name=convert_element_type])" -> "949([Symbol name=cat])" - "946([Symbol name=convert_element_type])" - "945([Symbol name=slice_prim])" -> "946([Symbol name=convert_element_type])" - "961([Symbol name=convert_element_type])" - "960([Symbol name=slice_prim])" -> "961([Symbol name=convert_element_type])" - "964([Symbol name=cat])" - "963([Symbol name=convert_element_type])" -> "964([Symbol name=cat])" - "959([Symbol name=slice_prim])" -> "964([Symbol name=cat])" - "979([Symbol name=reshape])" - "978([Symbol name=transpose])" -> "979([Symbol name=reshape])" - "996([Symbol name=convert_element_type])" - "994([Symbol name=convert_element_type])" -> "996([Symbol name=convert_element_type])" - "988([Symbol name=broadcast_in_dim])" - "987([Symbol name=sum])" -> "988([Symbol name=broadcast_in_dim])" - "1009([Symbol name=convert_element_type])" - "1007([Symbol name=convert_element_type])" -> "1009([Symbol name=convert_element_type])" - "1032([Symbol name=convert_element_type])" - "1030([Symbol name=convert_element_type])" -> "1032([Symbol name=convert_element_type])" - "1024([Symbol name=broadcast_in_dim])" - "1023([Symbol name=sum])" -> "1024([Symbol name=broadcast_in_dim])" - "1049([Symbol name=cat])" - "1048([Symbol name=convert_element_type])" -> "1049([Symbol name=cat])" - "1044([Symbol name=slice_prim])" -> "1049([Symbol name=cat])" - "1046([Symbol name=convert_element_type])" - "1045([Symbol name=slice_prim])" -> "1046([Symbol name=convert_element_type])" - "1064([Symbol name=cat])" - "1059([Symbol name=slice_prim])" -> "1064([Symbol name=cat])" - "1063([Symbol name=convert_element_type])" -> "1064([Symbol name=cat])" - "1061([Symbol name=convert_element_type])" - "1060([Symbol name=slice_prim])" -> "1061([Symbol name=convert_element_type])" - "1079([Symbol name=reshape])" - "1078([Symbol name=transpose])" -> "1079([Symbol name=reshape])" - "1096([Symbol name=convert_element_type])" - "1094([Symbol name=convert_element_type])" -> "1096([Symbol name=convert_element_type])" - "1088([Symbol name=broadcast_in_dim])" - "1087([Symbol name=sum])" -> "1088([Symbol name=broadcast_in_dim])" - "1109([Symbol name=convert_element_type])" - "1107([Symbol name=convert_element_type])" -> "1109([Symbol name=convert_element_type])" - "1132([Symbol name=convert_element_type])" - "1130([Symbol name=convert_element_type])" -> "1132([Symbol name=convert_element_type])" - "1124([Symbol name=broadcast_in_dim])" - "1123([Symbol name=sum])" -> "1124([Symbol name=broadcast_in_dim])" - "1149([Symbol name=cat])" - "1144([Symbol name=slice_prim])" -> "1149([Symbol name=cat])" - "1148([Symbol name=convert_element_type])" -> "1149([Symbol name=cat])" - "1146([Symbol name=convert_element_type])" - "1145([Symbol name=slice_prim])" -> "1146([Symbol name=convert_element_type])" - "1161([Symbol name=convert_element_type])" - "1160([Symbol name=slice_prim])" -> "1161([Symbol name=convert_element_type])" - "1164([Symbol name=cat])" - "1163([Symbol name=convert_element_type])" -> "1164([Symbol name=cat])" - "1159([Symbol name=slice_prim])" -> "1164([Symbol name=cat])" - "1179([Symbol name=reshape])" - "1178([Symbol name=transpose])" -> "1179([Symbol name=reshape])" - "1196([Symbol name=convert_element_type])" - "1194([Symbol name=convert_element_type])" -> "1196([Symbol name=convert_element_type])" - "1188([Symbol name=broadcast_in_dim])" - "1187([Symbol name=sum])" -> "1188([Symbol name=broadcast_in_dim])" - "1209([Symbol name=convert_element_type])" - "1207([Symbol name=convert_element_type])" -> "1209([Symbol name=convert_element_type])" - "1232([Symbol name=convert_element_type])" - "1230([Symbol name=convert_element_type])" -> "1232([Symbol name=convert_element_type])" - "1224([Symbol name=broadcast_in_dim])" - "1223([Symbol name=sum])" -> "1224([Symbol name=broadcast_in_dim])" - "1249([Symbol name=cat])" - "1248([Symbol name=convert_element_type])" -> "1249([Symbol name=cat])" - "1244([Symbol name=slice_prim])" -> "1249([Symbol name=cat])" - "1246([Symbol name=convert_element_type])" - "1245([Symbol name=slice_prim])" -> "1246([Symbol name=convert_element_type])" - "1264([Symbol name=cat])" - "1259([Symbol name=slice_prim])" -> "1264([Symbol name=cat])" - "1263([Symbol name=convert_element_type])" -> "1264([Symbol name=cat])" - "1261([Symbol name=convert_element_type])" - "1260([Symbol name=slice_prim])" -> "1261([Symbol name=convert_element_type])" - "1279([Symbol name=reshape])" - "1278([Symbol name=transpose])" -> "1279([Symbol name=reshape])" - "1296([Symbol name=convert_element_type])" - "1294([Symbol name=convert_element_type])" -> "1296([Symbol name=convert_element_type])" - "1288([Symbol name=broadcast_in_dim])" - "1287([Symbol name=sum])" -> "1288([Symbol name=broadcast_in_dim])" - "1309([Symbol name=convert_element_type])" - "1307([Symbol name=convert_element_type])" -> "1309([Symbol name=convert_element_type])" - "1332([Symbol name=convert_element_type])" - "1330([Symbol name=convert_element_type])" -> "1332([Symbol name=convert_element_type])" - "1324([Symbol name=broadcast_in_dim])" - "1323([Symbol name=sum])" -> "1324([Symbol name=broadcast_in_dim])" - "1349([Symbol name=cat])" - "1344([Symbol name=slice_prim])" -> "1349([Symbol name=cat])" - "1348([Symbol name=convert_element_type])" -> "1349([Symbol name=cat])" - "1346([Symbol name=convert_element_type])" - "1345([Symbol name=slice_prim])" -> "1346([Symbol name=convert_element_type])" - "1361([Symbol name=convert_element_type])" - "1360([Symbol name=slice_prim])" -> "1361([Symbol name=convert_element_type])" - "1364([Symbol name=cat])" - "1363([Symbol name=convert_element_type])" -> "1364([Symbol name=cat])" - "1359([Symbol name=slice_prim])" -> "1364([Symbol name=cat])" - "1379([Symbol name=reshape])" - "1378([Symbol name=transpose])" -> "1379([Symbol name=reshape])" - "1396([Symbol name=convert_element_type])" - "1394([Symbol name=convert_element_type])" -> "1396([Symbol name=convert_element_type])" - "1388([Symbol name=broadcast_in_dim])" - "1387([Symbol name=sum])" -> "1388([Symbol name=broadcast_in_dim])" - "1409([Symbol name=convert_element_type])" - "1407([Symbol name=convert_element_type])" -> "1409([Symbol name=convert_element_type])" - "1432([Symbol name=convert_element_type])" - "1430([Symbol name=convert_element_type])" -> "1432([Symbol name=convert_element_type])" - "1424([Symbol name=broadcast_in_dim])" - "1423([Symbol name=sum])" -> "1424([Symbol name=broadcast_in_dim])" - "1449([Symbol name=cat])" - "1448([Symbol name=convert_element_type])" -> "1449([Symbol name=cat])" - "1444([Symbol name=slice_prim])" -> "1449([Symbol name=cat])" - "1446([Symbol name=convert_element_type])" - "1445([Symbol name=slice_prim])" -> "1446([Symbol name=convert_element_type])" - "1464([Symbol name=cat])" - "1459([Symbol name=slice_prim])" -> "1464([Symbol name=cat])" - "1463([Symbol name=convert_element_type])" -> "1464([Symbol name=cat])" - "1461([Symbol name=convert_element_type])" - "1460([Symbol name=slice_prim])" -> "1461([Symbol name=convert_element_type])" - "1479([Symbol name=reshape])" - "1478([Symbol name=transpose])" -> "1479([Symbol name=reshape])" - "1496([Symbol name=convert_element_type])" - "1494([Symbol name=convert_element_type])" -> "1496([Symbol name=convert_element_type])" - "1488([Symbol name=broadcast_in_dim])" - "1487([Symbol name=sum])" -> "1488([Symbol name=broadcast_in_dim])" - "1509([Symbol name=convert_element_type])" - "1507([Symbol name=convert_element_type])" -> "1509([Symbol name=convert_element_type])" - "1532([Symbol name=convert_element_type])" - "1530([Symbol name=convert_element_type])" -> "1532([Symbol name=convert_element_type])" - "1524([Symbol name=broadcast_in_dim])" - "1523([Symbol name=sum])" -> "1524([Symbol name=broadcast_in_dim])" - "1549([Symbol name=cat])" - "1544([Symbol name=slice_prim])" -> "1549([Symbol name=cat])" - "1548([Symbol name=convert_element_type])" -> "1549([Symbol name=cat])" - "1546([Symbol name=convert_element_type])" - "1545([Symbol name=slice_prim])" -> "1546([Symbol name=convert_element_type])" - "1561([Symbol name=convert_element_type])" - "1560([Symbol name=slice_prim])" -> "1561([Symbol name=convert_element_type])" - "1564([Symbol name=cat])" - "1563([Symbol name=convert_element_type])" -> "1564([Symbol name=cat])" - "1559([Symbol name=slice_prim])" -> "1564([Symbol name=cat])" - "1579([Symbol name=reshape])" - "1578([Symbol name=transpose])" -> "1579([Symbol name=reshape])" - "1596([Symbol name=convert_element_type])" - "1594([Symbol name=convert_element_type])" -> "1596([Symbol name=convert_element_type])" - "1588([Symbol name=broadcast_in_dim])" - "1587([Symbol name=sum])" -> "1588([Symbol name=broadcast_in_dim])" - "1609([Symbol name=convert_element_type])" - "1607([Symbol name=convert_element_type])" -> "1609([Symbol name=convert_element_type])" - "1632([Symbol name=convert_element_type])" - "1630([Symbol name=convert_element_type])" -> "1632([Symbol name=convert_element_type])" - "1624([Symbol name=broadcast_in_dim])" - "1623([Symbol name=sum])" -> "1624([Symbol name=broadcast_in_dim])" - "1649([Symbol name=cat])" - "1648([Symbol name=convert_element_type])" -> "1649([Symbol name=cat])" - "1644([Symbol name=slice_prim])" -> "1649([Symbol name=cat])" - "1646([Symbol name=convert_element_type])" - "1645([Symbol name=slice_prim])" -> "1646([Symbol name=convert_element_type])" - "1664([Symbol name=cat])" - "1659([Symbol name=slice_prim])" -> "1664([Symbol name=cat])" - "1663([Symbol name=convert_element_type])" -> "1664([Symbol name=cat])" - "1661([Symbol name=convert_element_type])" - "1660([Symbol name=slice_prim])" -> "1661([Symbol name=convert_element_type])" - "1679([Symbol name=reshape])" - "1678([Symbol name=transpose])" -> "1679([Symbol name=reshape])" - "1696([Symbol name=convert_element_type])" - "1694([Symbol name=convert_element_type])" -> "1696([Symbol name=convert_element_type])" - "1688([Symbol name=broadcast_in_dim])" - "1687([Symbol name=sum])" -> "1688([Symbol name=broadcast_in_dim])" - "1709([Symbol name=convert_element_type])" - "1707([Symbol name=convert_element_type])" -> "1709([Symbol name=convert_element_type])" - "1732([Symbol name=convert_element_type])" - "1730([Symbol name=convert_element_type])" -> "1732([Symbol name=convert_element_type])" - "1724([Symbol name=broadcast_in_dim])" - "1723([Symbol name=sum])" -> "1724([Symbol name=broadcast_in_dim])" - "128([Symbol name=broadcast_in_dim])" - "127([Symbol name=rsqrt])" -> "128([Symbol name=broadcast_in_dim])" - "189([Symbol name=true_divide])" - "188([Symbol name=broadcast_in_dim])" -> "189([Symbol name=true_divide])" - "154([Symbol name=convert_element_type])" - "149([Symbol name=cat])" -> "154([Symbol name=convert_element_type])" - "147([Symbol name=neg])" - "146([Symbol name=convert_element_type])" -> "147([Symbol name=neg])" - "162([Symbol name=neg])" - "161([Symbol name=convert_element_type])" -> "162([Symbol name=neg])" - "169([Symbol name=convert_element_type])" - "164([Symbol name=cat])" -> "169([Symbol name=convert_element_type])" - "225([Symbol name=true_divide])" - "224([Symbol name=broadcast_in_dim])" -> "225([Symbol name=true_divide])" - "254([Symbol name=convert_element_type])" - "249([Symbol name=cat])" -> "254([Symbol name=convert_element_type])" - "247([Symbol name=neg])" - "246([Symbol name=convert_element_type])" -> "247([Symbol name=neg])" - "269([Symbol name=convert_element_type])" - "264([Symbol name=cat])" -> "269([Symbol name=convert_element_type])" - "262([Symbol name=neg])" - "261([Symbol name=convert_element_type])" -> "262([Symbol name=neg])" - "289([Symbol name=true_divide])" - "288([Symbol name=broadcast_in_dim])" -> "289([Symbol name=true_divide])" - "325([Symbol name=true_divide])" - "324([Symbol name=broadcast_in_dim])" -> "325([Symbol name=true_divide])" - "354([Symbol name=convert_element_type])" - "349([Symbol name=cat])" -> "354([Symbol name=convert_element_type])" - "347([Symbol name=neg])" - "346([Symbol name=convert_element_type])" -> "347([Symbol name=neg])" - "362([Symbol name=neg])" - "361([Symbol name=convert_element_type])" -> "362([Symbol name=neg])" - "369([Symbol name=convert_element_type])" - "364([Symbol name=cat])" -> "369([Symbol name=convert_element_type])" - "389([Symbol name=true_divide])" - "388([Symbol name=broadcast_in_dim])" -> "389([Symbol name=true_divide])" - "425([Symbol name=true_divide])" - "424([Symbol name=broadcast_in_dim])" -> "425([Symbol name=true_divide])" - "454([Symbol name=convert_element_type])" - "449([Symbol name=cat])" -> "454([Symbol name=convert_element_type])" - "447([Symbol name=neg])" - "446([Symbol name=convert_element_type])" -> "447([Symbol name=neg])" - "469([Symbol name=convert_element_type])" - "464([Symbol name=cat])" -> "469([Symbol name=convert_element_type])" - "462([Symbol name=neg])" - "461([Symbol name=convert_element_type])" -> "462([Symbol name=neg])" - "489([Symbol name=true_divide])" - "488([Symbol name=broadcast_in_dim])" -> "489([Symbol name=true_divide])" - "525([Symbol name=true_divide])" - "524([Symbol name=broadcast_in_dim])" -> "525([Symbol name=true_divide])" - "554([Symbol name=convert_element_type])" - "549([Symbol name=cat])" -> "554([Symbol name=convert_element_type])" - "547([Symbol name=neg])" - "546([Symbol name=convert_element_type])" -> "547([Symbol name=neg])" - "562([Symbol name=neg])" - "561([Symbol name=convert_element_type])" -> "562([Symbol name=neg])" - "569([Symbol name=convert_element_type])" - "564([Symbol name=cat])" -> "569([Symbol name=convert_element_type])" - "589([Symbol name=true_divide])" - "588([Symbol name=broadcast_in_dim])" -> "589([Symbol name=true_divide])" - "625([Symbol name=true_divide])" - "624([Symbol name=broadcast_in_dim])" -> "625([Symbol name=true_divide])" - "654([Symbol name=convert_element_type])" - "649([Symbol name=cat])" -> "654([Symbol name=convert_element_type])" - "647([Symbol name=neg])" - "646([Symbol name=convert_element_type])" -> "647([Symbol name=neg])" - "669([Symbol name=convert_element_type])" - "664([Symbol name=cat])" -> "669([Symbol name=convert_element_type])" - "662([Symbol name=neg])" - "661([Symbol name=convert_element_type])" -> "662([Symbol name=neg])" - "689([Symbol name=true_divide])" - "688([Symbol name=broadcast_in_dim])" -> "689([Symbol name=true_divide])" - "725([Symbol name=true_divide])" - "724([Symbol name=broadcast_in_dim])" -> "725([Symbol name=true_divide])" - "754([Symbol name=convert_element_type])" - "749([Symbol name=cat])" -> "754([Symbol name=convert_element_type])" - "747([Symbol name=neg])" - "746([Symbol name=convert_element_type])" -> "747([Symbol name=neg])" - "762([Symbol name=neg])" - "761([Symbol name=convert_element_type])" -> "762([Symbol name=neg])" - "769([Symbol name=convert_element_type])" - "764([Symbol name=cat])" -> "769([Symbol name=convert_element_type])" - "789([Symbol name=true_divide])" - "788([Symbol name=broadcast_in_dim])" -> "789([Symbol name=true_divide])" - "825([Symbol name=true_divide])" - "824([Symbol name=broadcast_in_dim])" -> "825([Symbol name=true_divide])" - "854([Symbol name=convert_element_type])" - "849([Symbol name=cat])" -> "854([Symbol name=convert_element_type])" - "847([Symbol name=neg])" - "846([Symbol name=convert_element_type])" -> "847([Symbol name=neg])" - "869([Symbol name=convert_element_type])" - "864([Symbol name=cat])" -> "869([Symbol name=convert_element_type])" - "862([Symbol name=neg])" - "861([Symbol name=convert_element_type])" -> "862([Symbol name=neg])" - "889([Symbol name=true_divide])" - "888([Symbol name=broadcast_in_dim])" -> "889([Symbol name=true_divide])" - "925([Symbol name=true_divide])" - "924([Symbol name=broadcast_in_dim])" -> "925([Symbol name=true_divide])" - "954([Symbol name=convert_element_type])" - "949([Symbol name=cat])" -> "954([Symbol name=convert_element_type])" - "947([Symbol name=neg])" - "946([Symbol name=convert_element_type])" -> "947([Symbol name=neg])" - "962([Symbol name=neg])" - "961([Symbol name=convert_element_type])" -> "962([Symbol name=neg])" - "969([Symbol name=convert_element_type])" - "964([Symbol name=cat])" -> "969([Symbol name=convert_element_type])" - "989([Symbol name=true_divide])" - "988([Symbol name=broadcast_in_dim])" -> "989([Symbol name=true_divide])" - "1025([Symbol name=true_divide])" - "1024([Symbol name=broadcast_in_dim])" -> "1025([Symbol name=true_divide])" - "1054([Symbol name=convert_element_type])" - "1049([Symbol name=cat])" -> "1054([Symbol name=convert_element_type])" - "1047([Symbol name=neg])" - "1046([Symbol name=convert_element_type])" -> "1047([Symbol name=neg])" - "1069([Symbol name=convert_element_type])" - "1064([Symbol name=cat])" -> "1069([Symbol name=convert_element_type])" - "1062([Symbol name=neg])" - "1061([Symbol name=convert_element_type])" -> "1062([Symbol name=neg])" - "1089([Symbol name=true_divide])" - "1088([Symbol name=broadcast_in_dim])" -> "1089([Symbol name=true_divide])" - "1125([Symbol name=true_divide])" - "1124([Symbol name=broadcast_in_dim])" -> "1125([Symbol name=true_divide])" - "1154([Symbol name=convert_element_type])" - "1149([Symbol name=cat])" -> "1154([Symbol name=convert_element_type])" - "1147([Symbol name=neg])" - "1146([Symbol name=convert_element_type])" -> "1147([Symbol name=neg])" - "1162([Symbol name=neg])" - "1161([Symbol name=convert_element_type])" -> "1162([Symbol name=neg])" - "1169([Symbol name=convert_element_type])" - "1164([Symbol name=cat])" -> "1169([Symbol name=convert_element_type])" - "1189([Symbol name=true_divide])" - "1188([Symbol name=broadcast_in_dim])" -> "1189([Symbol name=true_divide])" - "1225([Symbol name=true_divide])" - "1224([Symbol name=broadcast_in_dim])" -> "1225([Symbol name=true_divide])" - "1254([Symbol name=convert_element_type])" - "1249([Symbol name=cat])" -> "1254([Symbol name=convert_element_type])" - "1247([Symbol name=neg])" - "1246([Symbol name=convert_element_type])" -> "1247([Symbol name=neg])" - "1269([Symbol name=convert_element_type])" - "1264([Symbol name=cat])" -> "1269([Symbol name=convert_element_type])" - "1262([Symbol name=neg])" - "1261([Symbol name=convert_element_type])" -> "1262([Symbol name=neg])" - "1289([Symbol name=true_divide])" - "1288([Symbol name=broadcast_in_dim])" -> "1289([Symbol name=true_divide])" - "1325([Symbol name=true_divide])" - "1324([Symbol name=broadcast_in_dim])" -> "1325([Symbol name=true_divide])" - "1354([Symbol name=convert_element_type])" - "1349([Symbol name=cat])" -> "1354([Symbol name=convert_element_type])" - "1347([Symbol name=neg])" - "1346([Symbol name=convert_element_type])" -> "1347([Symbol name=neg])" - "1362([Symbol name=neg])" - "1361([Symbol name=convert_element_type])" -> "1362([Symbol name=neg])" - "1369([Symbol name=convert_element_type])" - "1364([Symbol name=cat])" -> "1369([Symbol name=convert_element_type])" - "1389([Symbol name=true_divide])" - "1388([Symbol name=broadcast_in_dim])" -> "1389([Symbol name=true_divide])" - "1425([Symbol name=true_divide])" - "1424([Symbol name=broadcast_in_dim])" -> "1425([Symbol name=true_divide])" - "1454([Symbol name=convert_element_type])" - "1449([Symbol name=cat])" -> "1454([Symbol name=convert_element_type])" - "1447([Symbol name=neg])" - "1446([Symbol name=convert_element_type])" -> "1447([Symbol name=neg])" - "1469([Symbol name=convert_element_type])" - "1464([Symbol name=cat])" -> "1469([Symbol name=convert_element_type])" - "1462([Symbol name=neg])" - "1461([Symbol name=convert_element_type])" -> "1462([Symbol name=neg])" - "1489([Symbol name=true_divide])" - "1488([Symbol name=broadcast_in_dim])" -> "1489([Symbol name=true_divide])" - "1525([Symbol name=true_divide])" - "1524([Symbol name=broadcast_in_dim])" -> "1525([Symbol name=true_divide])" - "1554([Symbol name=convert_element_type])" - "1549([Symbol name=cat])" -> "1554([Symbol name=convert_element_type])" - "1547([Symbol name=neg])" - "1546([Symbol name=convert_element_type])" -> "1547([Symbol name=neg])" - "1562([Symbol name=neg])" - "1561([Symbol name=convert_element_type])" -> "1562([Symbol name=neg])" - "1569([Symbol name=convert_element_type])" - "1564([Symbol name=cat])" -> "1569([Symbol name=convert_element_type])" - "1589([Symbol name=true_divide])" - "1588([Symbol name=broadcast_in_dim])" -> "1589([Symbol name=true_divide])" - "1625([Symbol name=true_divide])" - "1624([Symbol name=broadcast_in_dim])" -> "1625([Symbol name=true_divide])" - "1654([Symbol name=convert_element_type])" - "1649([Symbol name=cat])" -> "1654([Symbol name=convert_element_type])" - "1647([Symbol name=neg])" - "1646([Symbol name=convert_element_type])" -> "1647([Symbol name=neg])" - "1669([Symbol name=convert_element_type])" - "1664([Symbol name=cat])" -> "1669([Symbol name=convert_element_type])" - "1662([Symbol name=neg])" - "1661([Symbol name=convert_element_type])" -> "1662([Symbol name=neg])" - "1689([Symbol name=true_divide])" - "1688([Symbol name=broadcast_in_dim])" -> "1689([Symbol name=true_divide])" - "1725([Symbol name=true_divide])" - "1724([Symbol name=broadcast_in_dim])" -> "1725([Symbol name=true_divide])" - "190([Symbol name=add])" - "189([Symbol name=true_divide])" -> "190([Symbol name=add])" - "148([Symbol name=convert_element_type])" - "147([Symbol name=neg])" -> "148([Symbol name=convert_element_type])" - "163([Symbol name=convert_element_type])" - "162([Symbol name=neg])" -> "163([Symbol name=convert_element_type])" - "226([Symbol name=add])" - "225([Symbol name=true_divide])" -> "226([Symbol name=add])" - "248([Symbol name=convert_element_type])" - "247([Symbol name=neg])" -> "248([Symbol name=convert_element_type])" - "263([Symbol name=convert_element_type])" - "262([Symbol name=neg])" -> "263([Symbol name=convert_element_type])" - "290([Symbol name=add])" - "289([Symbol name=true_divide])" -> "290([Symbol name=add])" - "326([Symbol name=add])" - "325([Symbol name=true_divide])" -> "326([Symbol name=add])" - "348([Symbol name=convert_element_type])" - "347([Symbol name=neg])" -> "348([Symbol name=convert_element_type])" - "363([Symbol name=convert_element_type])" - "362([Symbol name=neg])" -> "363([Symbol name=convert_element_type])" - "390([Symbol name=add])" - "389([Symbol name=true_divide])" -> "390([Symbol name=add])" - "426([Symbol name=add])" - "425([Symbol name=true_divide])" -> "426([Symbol name=add])" - "448([Symbol name=convert_element_type])" - "447([Symbol name=neg])" -> "448([Symbol name=convert_element_type])" - "463([Symbol name=convert_element_type])" - "462([Symbol name=neg])" -> "463([Symbol name=convert_element_type])" - "490([Symbol name=add])" - "489([Symbol name=true_divide])" -> "490([Symbol name=add])" - "526([Symbol name=add])" - "525([Symbol name=true_divide])" -> "526([Symbol name=add])" - "548([Symbol name=convert_element_type])" - "547([Symbol name=neg])" -> "548([Symbol name=convert_element_type])" - "563([Symbol name=convert_element_type])" - "562([Symbol name=neg])" -> "563([Symbol name=convert_element_type])" - "590([Symbol name=add])" - "589([Symbol name=true_divide])" -> "590([Symbol name=add])" - "626([Symbol name=add])" - "625([Symbol name=true_divide])" -> "626([Symbol name=add])" - "648([Symbol name=convert_element_type])" - "647([Symbol name=neg])" -> "648([Symbol name=convert_element_type])" - "663([Symbol name=convert_element_type])" - "662([Symbol name=neg])" -> "663([Symbol name=convert_element_type])" - "690([Symbol name=add])" - "689([Symbol name=true_divide])" -> "690([Symbol name=add])" - "726([Symbol name=add])" - "725([Symbol name=true_divide])" -> "726([Symbol name=add])" - "748([Symbol name=convert_element_type])" - "747([Symbol name=neg])" -> "748([Symbol name=convert_element_type])" - "763([Symbol name=convert_element_type])" - "762([Symbol name=neg])" -> "763([Symbol name=convert_element_type])" - "790([Symbol name=add])" - "789([Symbol name=true_divide])" -> "790([Symbol name=add])" - "826([Symbol name=add])" - "825([Symbol name=true_divide])" -> "826([Symbol name=add])" - "848([Symbol name=convert_element_type])" - "847([Symbol name=neg])" -> "848([Symbol name=convert_element_type])" - "863([Symbol name=convert_element_type])" - "862([Symbol name=neg])" -> "863([Symbol name=convert_element_type])" - "890([Symbol name=add])" - "889([Symbol name=true_divide])" -> "890([Symbol name=add])" - "926([Symbol name=add])" - "925([Symbol name=true_divide])" -> "926([Symbol name=add])" - "948([Symbol name=convert_element_type])" - "947([Symbol name=neg])" -> "948([Symbol name=convert_element_type])" - "963([Symbol name=convert_element_type])" - "962([Symbol name=neg])" -> "963([Symbol name=convert_element_type])" - "990([Symbol name=add])" - "989([Symbol name=true_divide])" -> "990([Symbol name=add])" - "1026([Symbol name=add])" - "1025([Symbol name=true_divide])" -> "1026([Symbol name=add])" - "1048([Symbol name=convert_element_type])" - "1047([Symbol name=neg])" -> "1048([Symbol name=convert_element_type])" - "1063([Symbol name=convert_element_type])" - "1062([Symbol name=neg])" -> "1063([Symbol name=convert_element_type])" - "1090([Symbol name=add])" - "1089([Symbol name=true_divide])" -> "1090([Symbol name=add])" - "1126([Symbol name=add])" - "1125([Symbol name=true_divide])" -> "1126([Symbol name=add])" - "1148([Symbol name=convert_element_type])" - "1147([Symbol name=neg])" -> "1148([Symbol name=convert_element_type])" - "1163([Symbol name=convert_element_type])" - "1162([Symbol name=neg])" -> "1163([Symbol name=convert_element_type])" - "1190([Symbol name=add])" - "1189([Symbol name=true_divide])" -> "1190([Symbol name=add])" - "1226([Symbol name=add])" - "1225([Symbol name=true_divide])" -> "1226([Symbol name=add])" - "1248([Symbol name=convert_element_type])" - "1247([Symbol name=neg])" -> "1248([Symbol name=convert_element_type])" - "1263([Symbol name=convert_element_type])" - "1262([Symbol name=neg])" -> "1263([Symbol name=convert_element_type])" - "1290([Symbol name=add])" - "1289([Symbol name=true_divide])" -> "1290([Symbol name=add])" - "1326([Symbol name=add])" - "1325([Symbol name=true_divide])" -> "1326([Symbol name=add])" - "1348([Symbol name=convert_element_type])" - "1347([Symbol name=neg])" -> "1348([Symbol name=convert_element_type])" - "1363([Symbol name=convert_element_type])" - "1362([Symbol name=neg])" -> "1363([Symbol name=convert_element_type])" - "1390([Symbol name=add])" - "1389([Symbol name=true_divide])" -> "1390([Symbol name=add])" - "1426([Symbol name=add])" - "1425([Symbol name=true_divide])" -> "1426([Symbol name=add])" - "1448([Symbol name=convert_element_type])" - "1447([Symbol name=neg])" -> "1448([Symbol name=convert_element_type])" - "1463([Symbol name=convert_element_type])" - "1462([Symbol name=neg])" -> "1463([Symbol name=convert_element_type])" - "1490([Symbol name=add])" - "1489([Symbol name=true_divide])" -> "1490([Symbol name=add])" - "1526([Symbol name=add])" - "1525([Symbol name=true_divide])" -> "1526([Symbol name=add])" - "1548([Symbol name=convert_element_type])" - "1547([Symbol name=neg])" -> "1548([Symbol name=convert_element_type])" - "1563([Symbol name=convert_element_type])" - "1562([Symbol name=neg])" -> "1563([Symbol name=convert_element_type])" - "1590([Symbol name=add])" - "1589([Symbol name=true_divide])" -> "1590([Symbol name=add])" - "1626([Symbol name=add])" - "1625([Symbol name=true_divide])" -> "1626([Symbol name=add])" - "1648([Symbol name=convert_element_type])" - "1647([Symbol name=neg])" -> "1648([Symbol name=convert_element_type])" - "1663([Symbol name=convert_element_type])" - "1662([Symbol name=neg])" -> "1663([Symbol name=convert_element_type])" - "1690([Symbol name=add])" - "1689([Symbol name=true_divide])" -> "1690([Symbol name=add])" - "1726([Symbol name=add])" - "1725([Symbol name=true_divide])" -> "1726([Symbol name=add])" - "191([Symbol name=rsqrt])" - "190([Symbol name=add])" -> "191([Symbol name=rsqrt])" - "227([Symbol name=rsqrt])" - "226([Symbol name=add])" -> "227([Symbol name=rsqrt])" - "291([Symbol name=rsqrt])" - "290([Symbol name=add])" -> "291([Symbol name=rsqrt])" - "327([Symbol name=rsqrt])" - "326([Symbol name=add])" -> "327([Symbol name=rsqrt])" - "391([Symbol name=rsqrt])" - "390([Symbol name=add])" -> "391([Symbol name=rsqrt])" - "427([Symbol name=rsqrt])" - "426([Symbol name=add])" -> "427([Symbol name=rsqrt])" - "491([Symbol name=rsqrt])" - "490([Symbol name=add])" -> "491([Symbol name=rsqrt])" - "527([Symbol name=rsqrt])" - "526([Symbol name=add])" -> "527([Symbol name=rsqrt])" - "591([Symbol name=rsqrt])" - "590([Symbol name=add])" -> "591([Symbol name=rsqrt])" - "627([Symbol name=rsqrt])" - "626([Symbol name=add])" -> "627([Symbol name=rsqrt])" - "691([Symbol name=rsqrt])" - "690([Symbol name=add])" -> "691([Symbol name=rsqrt])" - "727([Symbol name=rsqrt])" - "726([Symbol name=add])" -> "727([Symbol name=rsqrt])" - "791([Symbol name=rsqrt])" - "790([Symbol name=add])" -> "791([Symbol name=rsqrt])" - "827([Symbol name=rsqrt])" - "826([Symbol name=add])" -> "827([Symbol name=rsqrt])" - "891([Symbol name=rsqrt])" - "890([Symbol name=add])" -> "891([Symbol name=rsqrt])" - "927([Symbol name=rsqrt])" - "926([Symbol name=add])" -> "927([Symbol name=rsqrt])" - "991([Symbol name=rsqrt])" - "990([Symbol name=add])" -> "991([Symbol name=rsqrt])" - "1027([Symbol name=rsqrt])" - "1026([Symbol name=add])" -> "1027([Symbol name=rsqrt])" - "1091([Symbol name=rsqrt])" - "1090([Symbol name=add])" -> "1091([Symbol name=rsqrt])" - "1127([Symbol name=rsqrt])" - "1126([Symbol name=add])" -> "1127([Symbol name=rsqrt])" - "1191([Symbol name=rsqrt])" - "1190([Symbol name=add])" -> "1191([Symbol name=rsqrt])" - "1227([Symbol name=rsqrt])" - "1226([Symbol name=add])" -> "1227([Symbol name=rsqrt])" - "1291([Symbol name=rsqrt])" - "1290([Symbol name=add])" -> "1291([Symbol name=rsqrt])" - "1327([Symbol name=rsqrt])" - "1326([Symbol name=add])" -> "1327([Symbol name=rsqrt])" - "1391([Symbol name=rsqrt])" - "1390([Symbol name=add])" -> "1391([Symbol name=rsqrt])" - "1427([Symbol name=rsqrt])" - "1426([Symbol name=add])" -> "1427([Symbol name=rsqrt])" - "1491([Symbol name=rsqrt])" - "1490([Symbol name=add])" -> "1491([Symbol name=rsqrt])" - "1527([Symbol name=rsqrt])" - "1526([Symbol name=add])" -> "1527([Symbol name=rsqrt])" - "1591([Symbol name=rsqrt])" - "1590([Symbol name=add])" -> "1591([Symbol name=rsqrt])" - "1627([Symbol name=rsqrt])" - "1626([Symbol name=add])" -> "1627([Symbol name=rsqrt])" - "1691([Symbol name=rsqrt])" - "1690([Symbol name=add])" -> "1691([Symbol name=rsqrt])" - "1727([Symbol name=rsqrt])" - "1726([Symbol name=add])" -> "1727([Symbol name=rsqrt])" - "192([Symbol name=broadcast_in_dim])" - "191([Symbol name=rsqrt])" -> "192([Symbol name=broadcast_in_dim])" - "228([Symbol name=broadcast_in_dim])" - "227([Symbol name=rsqrt])" -> "228([Symbol name=broadcast_in_dim])" - "292([Symbol name=broadcast_in_dim])" - "291([Symbol name=rsqrt])" -> "292([Symbol name=broadcast_in_dim])" - "328([Symbol name=broadcast_in_dim])" - "327([Symbol name=rsqrt])" -> "328([Symbol name=broadcast_in_dim])" - "392([Symbol name=broadcast_in_dim])" - "391([Symbol name=rsqrt])" -> "392([Symbol name=broadcast_in_dim])" - "428([Symbol name=broadcast_in_dim])" - "427([Symbol name=rsqrt])" -> "428([Symbol name=broadcast_in_dim])" - "492([Symbol name=broadcast_in_dim])" - "491([Symbol name=rsqrt])" -> "492([Symbol name=broadcast_in_dim])" - "528([Symbol name=broadcast_in_dim])" - "527([Symbol name=rsqrt])" -> "528([Symbol name=broadcast_in_dim])" - "592([Symbol name=broadcast_in_dim])" - "591([Symbol name=rsqrt])" -> "592([Symbol name=broadcast_in_dim])" - "628([Symbol name=broadcast_in_dim])" - "627([Symbol name=rsqrt])" -> "628([Symbol name=broadcast_in_dim])" - "692([Symbol name=broadcast_in_dim])" - "691([Symbol name=rsqrt])" -> "692([Symbol name=broadcast_in_dim])" - "728([Symbol name=broadcast_in_dim])" - "727([Symbol name=rsqrt])" -> "728([Symbol name=broadcast_in_dim])" - "792([Symbol name=broadcast_in_dim])" - "791([Symbol name=rsqrt])" -> "792([Symbol name=broadcast_in_dim])" - "828([Symbol name=broadcast_in_dim])" - "827([Symbol name=rsqrt])" -> "828([Symbol name=broadcast_in_dim])" - "892([Symbol name=broadcast_in_dim])" - "891([Symbol name=rsqrt])" -> "892([Symbol name=broadcast_in_dim])" - "928([Symbol name=broadcast_in_dim])" - "927([Symbol name=rsqrt])" -> "928([Symbol name=broadcast_in_dim])" - "992([Symbol name=broadcast_in_dim])" - "991([Symbol name=rsqrt])" -> "992([Symbol name=broadcast_in_dim])" - "1028([Symbol name=broadcast_in_dim])" - "1027([Symbol name=rsqrt])" -> "1028([Symbol name=broadcast_in_dim])" - "1092([Symbol name=broadcast_in_dim])" - "1091([Symbol name=rsqrt])" -> "1092([Symbol name=broadcast_in_dim])" - "1128([Symbol name=broadcast_in_dim])" - "1127([Symbol name=rsqrt])" -> "1128([Symbol name=broadcast_in_dim])" - "1192([Symbol name=broadcast_in_dim])" - "1191([Symbol name=rsqrt])" -> "1192([Symbol name=broadcast_in_dim])" - "1228([Symbol name=broadcast_in_dim])" - "1227([Symbol name=rsqrt])" -> "1228([Symbol name=broadcast_in_dim])" - "1292([Symbol name=broadcast_in_dim])" - "1291([Symbol name=rsqrt])" -> "1292([Symbol name=broadcast_in_dim])" - "1328([Symbol name=broadcast_in_dim])" - "1327([Symbol name=rsqrt])" -> "1328([Symbol name=broadcast_in_dim])" - "1392([Symbol name=broadcast_in_dim])" - "1391([Symbol name=rsqrt])" -> "1392([Symbol name=broadcast_in_dim])" - "1428([Symbol name=broadcast_in_dim])" - "1427([Symbol name=rsqrt])" -> "1428([Symbol name=broadcast_in_dim])" - "1492([Symbol name=broadcast_in_dim])" - "1491([Symbol name=rsqrt])" -> "1492([Symbol name=broadcast_in_dim])" - "1528([Symbol name=broadcast_in_dim])" - "1527([Symbol name=rsqrt])" -> "1528([Symbol name=broadcast_in_dim])" - "1592([Symbol name=broadcast_in_dim])" - "1591([Symbol name=rsqrt])" -> "1592([Symbol name=broadcast_in_dim])" - "1628([Symbol name=broadcast_in_dim])" - "1627([Symbol name=rsqrt])" -> "1628([Symbol name=broadcast_in_dim])" - "1692([Symbol name=broadcast_in_dim])" - "1691([Symbol name=rsqrt])" -> "1692([Symbol name=broadcast_in_dim])" - "1728([Symbol name=broadcast_in_dim])" - "1727([Symbol name=rsqrt])" -> "1728([Symbol name=broadcast_in_dim])" -} diff --git a/examples/dev/litGPT.out b/examples/dev/litGPT.out deleted file mode 100644 index 2fa95319a6..0000000000 --- a/examples/dev/litGPT.out +++ /dev/null @@ -1,25381 +0,0 @@ -============================================ START: LABEL default -============================================ START: computation_trc split_forward_backward -# Constructed by Dead Code Elimination (took 4 milliseconds) -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight): - # idx: "cuda:0 i64[1, 512]" - # tos1: "cuda:0 f32[4096, 128]" - # t_lm_head_weight: "cuda:0 bf16[32000, 4096]" - # t_sin: "cuda:0 f32[4096, 128]" - # t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_0_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_0_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_1_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_1_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_1_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_2_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_2_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_2_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_3_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_3_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_3_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_4_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_4_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_4_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_5_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_5_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_5_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_6_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_6_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_6_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_7_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_7_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_7_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_8_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_8_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_8_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_9_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_9_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_9_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_10_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_10_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_10_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_11_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_11_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_11_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_12_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_12_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_12_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_13_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_13_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_13_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_14_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_14_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_14_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_15_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_15_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_15_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_ln_f_weight: "cuda:0 bf16[4096]" - # t_transformer_wte_weight: "cuda:0 bf16[32000, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:85: cos = self.cos[:T] - cos = ltorch.getitem(tos1, slice(None, 512, None)) # cos: "cuda:0 f32[512, 128]" - # cos = prims.slice_prim(tos1, [0, 0], [512, 128], [1, 1]) # cos: "cuda:0 f32[512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:86: sin = self.sin[:T] - sin = ltorch.getitem(t_sin, slice(None, 512, None)) # sin: "cuda:0 f32[512, 128]" - # sin = prims.slice_prim(t_sin, [0, 0], [512, 128], [1, 1]) # sin: "cuda:0 f32[512, 128]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py:190: return F.embedding( - x = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # x: "cuda:0 bf16[1, 512, 4096]" - # t16 = ltorch.reshape(idx, [512]) # t16: "cuda:0 i64[512]" - # t16 = prims.reshape(idx, (512,)) # t16: "cuda:0 i64[512]" - # t17 = prims.take(t_transformer_wte_weight, t16, 0) # t17: "cuda:0 bf16[512, 4096]" - # x = ltorch.reshape(t17, [1, 512, 4096]) # x: "cuda:0 bf16[1, 512, 4096]" - # x = prims.reshape(t17, (1, 512, 4096)) # x: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - a = prims.convert_element_type(x, dtypes.float32) # a: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - result = ltorch.mul(a, a) # result: "cuda:0 f32[1, 512, 4096]" - # result = prims.mul(a, a) # result: "cuda:0 f32[1, 512, 4096]" - norm_x = ltorch.mean(result, -1, True, dtype=None) # norm_x: "cuda:0 f32[1, 512, 1]" - # t24 = prims.sum(result, (2,)) # t24: "cuda:0 f32[1, 512]" - # t25 = prims.broadcast_in_dim(t24, [1, 512, 1], [0, 1]) # t25: "cuda:0 f32[1, 512, 1]" - # norm_x = ltorch.true_divide(t25, 4096) # norm_x: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # norm_x = prims.div(t25, 4096.0) # norm_x: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t28 = ltorch.add(norm_x, 1e-05, alpha=None) # t28: "cuda:0 f32[1, 512, 1]" - # t28 = prims.add(norm_x, 1e-05) # t28: "cuda:0 f32[1, 512, 1]" - b = ltorch.rsqrt(t28) # b: "cuda:0 f32[1, 512, 1]" - # b = prims.rsqrt(t28) # b: "cuda:0 f32[1, 512, 1]" - x_normed = ltorch.mul(a, b) # x_normed: "cuda:0 f32[1, 512, 4096]" - # t30 = prims.broadcast_in_dim(b, (1, 512, 4096), (0, 1, 2)) # t30: "cuda:0 f32[1, 512, 4096]" - # x_normed = prims.mul(a, t30) # x_normed: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t32 = ltorch.to(x_normed, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t32: "cuda:0 bf16[1, 512, 4096]" - # t32 = prims.convert_element_type(x_normed, dtypes.bfloat16) # t32: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - input = ltorch.mul(t32, t_transformer_h_0_norm_1_weight) # input: "cuda:0 bf16[1, 512, 4096]" - # t38 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, (1, 512, 4096), (2,)) # t38: "cuda:0 bf16[1, 512, 4096]" - # t39 = prims.convert_element_type(t32, dtypes.float32) # t39: "cuda:0 f32[1, 512, 4096]" - # t40 = prims.convert_element_type(t38, dtypes.float32) # t40: "cuda:0 f32[1, 512, 4096]" - # t41 = prims.mul(t39, t40) # t41: "cuda:0 f32[1, 512, 4096]" - # input = prims.convert_element_type(t41, dtypes.bfloat16) # input: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - qkv = ltorch.linear(input, t_transformer_h_0_attn_attn_weight, None) # qkv: "cuda:0 bf16[1, 512, 12288]" - # qkv = prims.linear(input, t_transformer_h_0_attn_attn_weight, None) # qkv: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t51 = ltorch.view(qkv, 1, 512, 32, 3, 128) # t51: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t51 = ltorch.reshape(qkv, (1, 512, 32, 3, 128)) # t51: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t51 = prims.reshape(qkv, (1, 512, 32, 3, 128)) # t51: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t52 = ltorch.permute(t51, 0, 2, 3, 1, 4) # t52: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t52 = prims.transpose(t51, (0, 2, 3, 1, 4)) # t52: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (res, k, v) = ltorch.split(t52, (1, 1, 1), 2) - # res = prims.slice_prim(t52, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # res: "cuda:0 bf16[1, 32, 1, 512, 128]" - # k = prims.slice_prim(t52, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # k: "cuda:0 bf16[1, 32, 1, 512, 128]" - # v = prims.slice_prim(t52, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # v: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - q = ltorch.reshape(res, 1, -1, 512, 128) # q: "cuda:0 bf16[1, 32, 512, 128]" - # q = prims.reshape(res, (1, 32, 512, 128)) # q: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t57 = ltorch.reshape(k, 1, -1, 512, 128) # t57: "cuda:0 bf16[1, 32, 512, 128]" - # t57 = prims.reshape(k, (1, 32, 512, 128)) # t57: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t58 = ltorch.reshape(v, 1, -1, 512, 128) # t58: "cuda:0 bf16[1, 32, 512, 128]" - # t58 = prims.reshape(v, (1, 32, 512, 128)) # t58: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t60 = ltorch.getitem(q, (..., slice(None, 128, None))) # t60: "cuda:0 bf16[1, 32, 512, 128]" - # t60 = prims.slice_prim(q, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t60: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - x1 = ltorch.getitem(t60, (..., slice(None, 64, None))) # x1: "cuda:0 bf16[1, 32, 512, 64]" - # x1 = prims.slice_prim(t60, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # x1: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - x2 = ltorch.getitem(t60, (..., slice(64, None, None))) # x2: "cuda:0 bf16[1, 32, 512, 64]" - # x2 = prims.slice_prim(t60, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # x2: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t65 = ltorch.neg(x2) # t65: "cuda:0 bf16[1, 32, 512, 64]" - # t63 = prims.convert_element_type(x2, dtypes.float32) # t63: "cuda:0 f32[1, 32, 512, 64]" - # t64 = prims.neg(t63) # t64: "cuda:0 f32[1, 32, 512, 64]" - # t65 = prims.convert_element_type(t64, dtypes.bfloat16) # t65: "cuda:0 bf16[1, 32, 512, 64]" - rotated = ltorch.cat((t65, x1), -1) # rotated: "cuda:0 bf16[1, 32, 512, 128]" - # rotated = prims.cat((t65, x1), -1) # rotated: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t69 = ltorch.mul(t60, cos) # t69: "cuda:0 f32[1, 32, 512, 128]" - # t67 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t67: "cuda:0 f32[1, 32, 512, 128]" - # t68 = prims.convert_element_type(t60, dtypes.float32) # t68: "cuda:0 f32[1, 32, 512, 128]" - # t69 = prims.mul(t68, t67) # t69: "cuda:0 f32[1, 32, 512, 128]" - t72 = ltorch.mul(rotated, sin) # t72: "cuda:0 f32[1, 32, 512, 128]" - # t70 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t70: "cuda:0 f32[1, 32, 512, 128]" - # t71 = prims.convert_element_type(rotated, dtypes.float32) # t71: "cuda:0 f32[1, 32, 512, 128]" - # t72 = prims.mul(t71, t70) # t72: "cuda:0 f32[1, 32, 512, 128]" - roped = ltorch.add(t69, t72, alpha=None) # roped: "cuda:0 f32[1, 32, 512, 128]" - # roped = prims.add(t69, t72) # roped: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - q_roped = ltorch.to(roped, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # q_roped: "cuda:0 bf16[1, 32, 512, 128]" - # q_roped = prims.convert_element_type(roped, dtypes.bfloat16) # q_roped: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t75 = ltorch.getitem(t57, (..., slice(None, 128, None))) # t75: "cuda:0 bf16[1, 32, 512, 128]" - # t75 = prims.slice_prim(t57, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t75: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t76 = ltorch.getitem(t75, (..., slice(None, 64, None))) # t76: "cuda:0 bf16[1, 32, 512, 64]" - # t76 = prims.slice_prim(t75, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t76: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - tos = ltorch.getitem(t75, (..., slice(64, None, None))) # tos: "cuda:0 bf16[1, 32, 512, 64]" - # tos = prims.slice_prim(t75, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # tos: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t80 = ltorch.neg(tos) # t80: "cuda:0 bf16[1, 32, 512, 64]" - # t78 = prims.convert_element_type(tos, dtypes.float32) # t78: "cuda:0 f32[1, 32, 512, 64]" - # t79 = prims.neg(t78) # t79: "cuda:0 f32[1, 32, 512, 64]" - # t80 = prims.convert_element_type(t79, dtypes.bfloat16) # t80: "cuda:0 bf16[1, 32, 512, 64]" - t81 = ltorch.cat((t80, t76), -1) # t81: "cuda:0 bf16[1, 32, 512, 128]" - # t81 = prims.cat((t80, t76), -1) # t81: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t84 = ltorch.mul(t75, cos) # t84: "cuda:0 f32[1, 32, 512, 128]" - # t82 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t82: "cuda:0 f32[1, 32, 512, 128]" - # t83 = prims.convert_element_type(t75, dtypes.float32) # t83: "cuda:0 f32[1, 32, 512, 128]" - # t84 = prims.mul(t83, t82) # t84: "cuda:0 f32[1, 32, 512, 128]" - t87 = ltorch.mul(t81, sin) # t87: "cuda:0 f32[1, 32, 512, 128]" - # t85 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t85: "cuda:0 f32[1, 32, 512, 128]" - # t86 = prims.convert_element_type(t81, dtypes.float32) # t86: "cuda:0 f32[1, 32, 512, 128]" - # t87 = prims.mul(t86, t85) # t87: "cuda:0 f32[1, 32, 512, 128]" - t88 = ltorch.add(t84, t87, alpha=None) # t88: "cuda:0 f32[1, 32, 512, 128]" - # t88 = prims.add(t84, t87) # t88: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - k_roped = ltorch.to(t88, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # k_roped: "cuda:0 bf16[1, 32, 512, 128]" - # k_roped = prims.convert_element_type(t88, dtypes.bfloat16) # k_roped: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t90 = ltorch.getitem(q, (..., slice(128, None, None))) # t90: "cuda:0 bf16[1, 32, 512, 0]" - # t90 = prims.slice_prim(q, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t90: "cuda:0 bf16[1, 32, 512, 0]" - t91 = ltorch.cat((q_roped, t90), -1) # t91: "cuda:0 bf16[1, 32, 512, 128]" - # t91 = prims.cat((q_roped, t90), -1) # t91: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t92 = ltorch.getitem(t57, (..., slice(128, None, None))) # t92: "cuda:0 bf16[1, 32, 512, 0]" - # t92 = prims.slice_prim(t57, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t92: "cuda:0 bf16[1, 32, 512, 0]" - t93 = ltorch.cat((k_roped, t92), -1) # t93: "cuda:0 bf16[1, 32, 512, 128]" - # t93 = prims.cat((k_roped, t92), -1) # t93: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - y = ltorch.scaled_dot_product_attention(t91, t93, t58, None, 0.0, True, scale=0.08838834764831843) # y: "cuda:0 bf16[1, 32, 512, 128]" - # t96 = ltorch.mul(t91, 0.29730177875068026) # t96: "cuda:0 bf16[1, 32, 512, 128]" - # t94 = prims.convert_element_type(t91, dtypes.float32) # t94: "cuda:0 f32[1, 32, 512, 128]" - # t95 = prims.mul(t94, 0.29730177875068026) # t95: "cuda:0 f32[1, 32, 512, 128]" - # t96 = prims.convert_element_type(t95, dtypes.bfloat16) # t96: "cuda:0 bf16[1, 32, 512, 128]" - # t97 = ltorch.transpose(t93, -2, -1) # t97: "cuda:0 bf16[1, 32, 128, 512]" - # t97 = prims.transpose(t93, (0, 1, 3, 2)) # t97: "cuda:0 bf16[1, 32, 128, 512]" - # t100 = ltorch.mul(t97, 0.29730177875068026) # t100: "cuda:0 bf16[1, 32, 128, 512]" - # t98 = prims.convert_element_type(t97, dtypes.float32) # t98: "cuda:0 f32[1, 32, 128, 512]" - # t99 = prims.mul(t98, 0.29730177875068026) # t99: "cuda:0 f32[1, 32, 128, 512]" - # t100 = prims.convert_element_type(t99, dtypes.bfloat16) # t100: "cuda:0 bf16[1, 32, 128, 512]" - # t101 = ltorch.matmul(t96, t100) # t101: "cuda:0 bf16[1, 32, 512, 512]" - # t101 = prims.matmul(t96, t100) # t101: "cuda:0 bf16[1, 32, 512, 512]" - # t111 = ltorch.tril(t101, 0, fill_value=-float('inf')) # t111: "cuda:0 bf16[1, 32, 512, 512]" - # t102 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t102: "cuda:0 i64[512]" - # t102 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t102: "cuda:0 i64[512]" - # t103 = ltorch.unsqueeze(t102, -1) # t103: "cuda:0 i64[512, 1]" - # t103 = prims.broadcast_in_dim(t102, [512, 1], [0]) # t103: "cuda:0 i64[512, 1]" - # t104 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t104: "cuda:0 i64[512]" - # t104 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t104: "cuda:0 i64[512]" - # t105 = ltorch.unsqueeze(t104, -2) # t105: "cuda:0 i64[1, 512]" - # t105 = prims.broadcast_in_dim(t104, [1, 512], [1]) # t105: "cuda:0 i64[1, 512]" - # t106 = ltorch.add(t103, 0, alpha=None) # t106: "cuda:0 i64[512, 1]" - # t106 = prims.add(t103, 0) # t106: "cuda:0 i64[512, 1]" - # t109 = ltorch.ge(t106, t105) # t109: "cuda:0 b8[512, 512]" - # t107 = prims.broadcast_in_dim(t106, (512, 512), (0, 1)) # t107: "cuda:0 i64[512, 512]" - # t108 = prims.broadcast_in_dim(t105, (512, 512), (0, 1)) # t108: "cuda:0 i64[512, 512]" - # t109 = prims.ge(t107, t108) # t109: "cuda:0 b8[512, 512]" - # t111 = ltorch.where(t109, t101, -float('inf')) # t111: "cuda:0 bf16[1, 32, 512, 512]" - # t110 = prims.broadcast_in_dim(t109, (1, 32, 512, 512), (2, 3)) # t110: "cuda:0 b8[1, 32, 512, 512]" - # t111 = prims.where(t110, t101, -float('inf')) # t111: "cuda:0 bf16[1, 32, 512, 512]" - # t122 = ltorch._softmax(t111, -1, dtype=None) # t122: "cuda:0 bf16[1, 32, 512, 512]" - # t112 = ltorch.to(t111, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t112: "cuda:0 f32[1, 32, 512, 512]" - # t112 = prims.convert_element_type(t111, dtypes.float32) # t112: "cuda:0 f32[1, 32, 512, 512]" - # t114 = ltorch.amax(t112, -1, True) # t114: "cuda:0 f32[1, 32, 512, 1]" - # t113 = prims.amax(t112, (3,)) # t113: "cuda:0 f32[1, 32, 512]" - # t114 = prims.broadcast_in_dim(t113, [1, 32, 512, 1], [0, 1, 2]) # t114: "cuda:0 f32[1, 32, 512, 1]" - # t116 = ltorch.sub(t112, t114, alpha=None) # t116: "cuda:0 f32[1, 32, 512, 512]" - # t115 = prims.broadcast_in_dim(t114, (1, 32, 512, 512), (0, 1, 2, 3)) # t115: "cuda:0 f32[1, 32, 512, 512]" - # t116 = prims.sub(t112, t115) # t116: "cuda:0 f32[1, 32, 512, 512]" - # t117 = ltorch.exp(t116) # t117: "cuda:0 f32[1, 32, 512, 512]" - # t117 = prims.exp(t116) # t117: "cuda:0 f32[1, 32, 512, 512]" - # t119 = ltorch.sum(t117, -1, True, dtype=None) # t119: "cuda:0 f32[1, 32, 512, 1]" - # t118 = prims.sum(t117, (3,)) # t118: "cuda:0 f32[1, 32, 512]" - # t119 = prims.broadcast_in_dim(t118, [1, 32, 512, 1], [0, 1, 2]) # t119: "cuda:0 f32[1, 32, 512, 1]" - # t121 = ltorch.true_divide(t117, t119) # t121: "cuda:0 f32[1, 32, 512, 512]" - # t120 = prims.broadcast_in_dim(t119, (1, 32, 512, 512), (0, 1, 2, 3)) # t120: "cuda:0 f32[1, 32, 512, 512]" - # t121 = prims.div(t117, t120) # t121: "cuda:0 f32[1, 32, 512, 512]" - # t122 = ltorch.to(t121, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t122: "cuda:0 bf16[1, 32, 512, 512]" - # t122 = prims.convert_element_type(t121, dtypes.bfloat16) # t122: "cuda:0 bf16[1, 32, 512, 512]" - # y = ltorch.matmul(t122, t58) # y: "cuda:0 bf16[1, 32, 512, 128]" - # y = prims.matmul(t122, t58) # y: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t124 = ltorch.transpose(y, 1, 2) # t124: "cuda:0 bf16[1, 512, 32, 128]" - # t124 = prims.transpose(y, (0, 2, 1, 3)) # t124: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t125 = ltorch.reshape(t124, 1, 512, 4096) # t125: "cuda:0 bf16[1, 512, 4096]" - # t125 = prims.reshape(t124, (1, 512, 4096)) # t125: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - attention_output = ltorch.linear(t125, t_transformer_h_0_attn_proj_weight, None) # attention_output: "cuda:0 bf16[1, 512, 4096]" - # attention_output = prims.linear(t125, t_transformer_h_0_attn_proj_weight, None) # attention_output: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t134 = ltorch.add(attention_output, x, alpha=None) # t134: "cuda:0 bf16[1, 512, 4096]" - # t131 = prims.convert_element_type(attention_output, dtypes.float32) # t131: "cuda:0 f32[1, 512, 4096]" - # t132 = prims.convert_element_type(x, dtypes.float32) # t132: "cuda:0 f32[1, 512, 4096]" - # t133 = prims.add(t131, t132) # t133: "cuda:0 f32[1, 512, 4096]" - # t134 = prims.convert_element_type(t133, dtypes.bfloat16) # t134: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t135 = prims.convert_element_type(t134, dtypes.float32) # t135: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t136 = ltorch.mul(t135, t135) # t136: "cuda:0 f32[1, 512, 4096]" - # t136 = prims.mul(t135, t135) # t136: "cuda:0 f32[1, 512, 4096]" - t140 = ltorch.mean(t136, -1, True, dtype=None) # t140: "cuda:0 f32[1, 512, 1]" - # t138 = prims.sum(t136, (2,)) # t138: "cuda:0 f32[1, 512]" - # t139 = prims.broadcast_in_dim(t138, [1, 512, 1], [0, 1]) # t139: "cuda:0 f32[1, 512, 1]" - # t140 = ltorch.true_divide(t139, 4096) # t140: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t140 = prims.div(t139, 4096.0) # t140: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t142 = ltorch.add(t140, 1e-05, alpha=None) # t142: "cuda:0 f32[1, 512, 1]" - # t142 = prims.add(t140, 1e-05) # t142: "cuda:0 f32[1, 512, 1]" - t143 = ltorch.rsqrt(t142) # t143: "cuda:0 f32[1, 512, 1]" - # t143 = prims.rsqrt(t142) # t143: "cuda:0 f32[1, 512, 1]" - t145 = ltorch.mul(t135, t143) # t145: "cuda:0 f32[1, 512, 4096]" - # t144 = prims.broadcast_in_dim(t143, (1, 512, 4096), (0, 1, 2)) # t144: "cuda:0 f32[1, 512, 4096]" - # t145 = prims.mul(t135, t144) # t145: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t146 = ltorch.to(t145, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t146: "cuda:0 bf16[1, 512, 4096]" - # t146 = prims.convert_element_type(t145, dtypes.bfloat16) # t146: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t156 = ltorch.mul(t146, t_transformer_h_0_norm_2_weight) # t156: "cuda:0 bf16[1, 512, 4096]" - # t152 = prims.broadcast_in_dim(t_transformer_h_0_norm_2_weight, (1, 512, 4096), (2,)) # t152: "cuda:0 bf16[1, 512, 4096]" - # t153 = prims.convert_element_type(t146, dtypes.float32) # t153: "cuda:0 f32[1, 512, 4096]" - # t154 = prims.convert_element_type(t152, dtypes.float32) # t154: "cuda:0 f32[1, 512, 4096]" - # t155 = prims.mul(t153, t154) # t155: "cuda:0 f32[1, 512, 4096]" - # t156 = prims.convert_element_type(t155, dtypes.bfloat16) # t156: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(t156, t_transformer_h_0_mlp_fc_1_weight, None) # x_fc_1: "cuda:0 bf16[1, 512, 11008]" - # x_fc_1 = prims.linear(t156, t_transformer_h_0_mlp_fc_1_weight, None) # x_fc_1: "cuda:0 bf16[1, 512, 11008]" - x_fc_2 = ltorch.linear(t156, t_transformer_h_0_mlp_fc_2_weight, None) # x_fc_2: "cuda:0 bf16[1, 512, 11008]" - # x_fc_2 = prims.linear(t156, t_transformer_h_0_mlp_fc_2_weight, None) # x_fc_2: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t175 = ltorch.silu(x_fc_1, False) # t175: "cuda:0 bf16[1, 512, 11008]" - # t166 = prims.convert_element_type(x_fc_1, dtypes.float32) # t166: "cuda:0 f32[1, 512, 11008]" - # t167 = prims.neg(t166) # t167: "cuda:0 f32[1, 512, 11008]" - # t168 = prims.exp(t167) # t168: "cuda:0 f32[1, 512, 11008]" - # t169 = prims.add(1.0, t168) # t169: "cuda:0 f32[1, 512, 11008]" - # t170 = prims.reciprocal(t169) # t170: "cuda:0 f32[1, 512, 11008]" - # t171 = prims.convert_element_type(t170, dtypes.bfloat16) # t171: "cuda:0 bf16[1, 512, 11008]" - # t172 = prims.convert_element_type(x_fc_1, dtypes.float32) # t172: "cuda:0 f32[1, 512, 11008]" - # t173 = prims.convert_element_type(t171, dtypes.float32) # t173: "cuda:0 f32[1, 512, 11008]" - # t174 = prims.mul(t172, t173) # t174: "cuda:0 f32[1, 512, 11008]" - # t175 = prims.convert_element_type(t174, dtypes.bfloat16) # t175: "cuda:0 bf16[1, 512, 11008]" - t179 = ltorch.mul(t175, x_fc_2) # t179: "cuda:0 bf16[1, 512, 11008]" - # t176 = prims.convert_element_type(t175, dtypes.float32) # t176: "cuda:0 f32[1, 512, 11008]" - # t177 = prims.convert_element_type(x_fc_2, dtypes.float32) # t177: "cuda:0 f32[1, 512, 11008]" - # t178 = prims.mul(t176, t177) # t178: "cuda:0 f32[1, 512, 11008]" - # t179 = prims.convert_element_type(t178, dtypes.bfloat16) # t179: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t183 = ltorch.linear(t179, t_transformer_h_0_mlp_proj_weight, None) # t183: "cuda:0 bf16[1, 512, 4096]" - # t183 = prims.linear(t179, t_transformer_h_0_mlp_proj_weight, None) # t183: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t187 = ltorch.add(t183, t134, alpha=None) # t187: "cuda:0 bf16[1, 512, 4096]" - # t184 = prims.convert_element_type(t183, dtypes.float32) # t184: "cuda:0 f32[1, 512, 4096]" - # t185 = prims.convert_element_type(t134, dtypes.float32) # t185: "cuda:0 f32[1, 512, 4096]" - # t186 = prims.add(t184, t185) # t186: "cuda:0 f32[1, 512, 4096]" - # t187 = prims.convert_element_type(t186, dtypes.bfloat16) # t187: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t189 = prims.convert_element_type(t187, dtypes.float32) # t189: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t190 = ltorch.mul(t189, t189) # t190: "cuda:0 f32[1, 512, 4096]" - # t190 = prims.mul(t189, t189) # t190: "cuda:0 f32[1, 512, 4096]" - t194 = ltorch.mean(t190, -1, True, dtype=None) # t194: "cuda:0 f32[1, 512, 1]" - # t192 = prims.sum(t190, (2,)) # t192: "cuda:0 f32[1, 512]" - # t193 = prims.broadcast_in_dim(t192, [1, 512, 1], [0, 1]) # t193: "cuda:0 f32[1, 512, 1]" - # t194 = ltorch.true_divide(t193, 4096) # t194: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t194 = prims.div(t193, 4096.0) # t194: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t196 = ltorch.add(t194, 1e-05, alpha=None) # t196: "cuda:0 f32[1, 512, 1]" - # t196 = prims.add(t194, 1e-05) # t196: "cuda:0 f32[1, 512, 1]" - t197 = ltorch.rsqrt(t196) # t197: "cuda:0 f32[1, 512, 1]" - # t197 = prims.rsqrt(t196) # t197: "cuda:0 f32[1, 512, 1]" - t199 = ltorch.mul(t189, t197) # t199: "cuda:0 f32[1, 512, 4096]" - # t198 = prims.broadcast_in_dim(t197, (1, 512, 4096), (0, 1, 2)) # t198: "cuda:0 f32[1, 512, 4096]" - # t199 = prims.mul(t189, t198) # t199: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t200 = ltorch.to(t199, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t200: "cuda:0 bf16[1, 512, 4096]" - # t200 = prims.convert_element_type(t199, dtypes.bfloat16) # t200: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t210 = ltorch.mul(t200, t_transformer_h_1_norm_1_weight) # t210: "cuda:0 bf16[1, 512, 4096]" - # t206 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_weight, (1, 512, 4096), (2,)) # t206: "cuda:0 bf16[1, 512, 4096]" - # t207 = prims.convert_element_type(t200, dtypes.float32) # t207: "cuda:0 f32[1, 512, 4096]" - # t208 = prims.convert_element_type(t206, dtypes.float32) # t208: "cuda:0 f32[1, 512, 4096]" - # t209 = prims.mul(t207, t208) # t209: "cuda:0 f32[1, 512, 4096]" - # t210 = prims.convert_element_type(t209, dtypes.bfloat16) # t210: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t215 = ltorch.linear(t210, t_transformer_h_1_attn_attn_weight, None) # t215: "cuda:0 bf16[1, 512, 12288]" - # t215 = prims.linear(t210, t_transformer_h_1_attn_attn_weight, None) # t215: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t216 = ltorch.view(t215, 1, 512, 32, 3, 128) # t216: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t216 = ltorch.reshape(t215, (1, 512, 32, 3, 128)) # t216: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t216 = prims.reshape(t215, (1, 512, 32, 3, 128)) # t216: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t217 = ltorch.permute(t216, 0, 2, 3, 1, 4) # t217: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t217 = prims.transpose(t216, (0, 2, 3, 1, 4)) # t217: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t218, t219, t220) = ltorch.split(t217, (1, 1, 1), 2) - # t218 = prims.slice_prim(t217, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t218: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t219 = prims.slice_prim(t217, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t219: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t220 = prims.slice_prim(t217, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t220: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t221 = ltorch.reshape(t218, 1, -1, 512, 128) # t221: "cuda:0 bf16[1, 32, 512, 128]" - # t221 = prims.reshape(t218, (1, 32, 512, 128)) # t221: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t222 = ltorch.reshape(t219, 1, -1, 512, 128) # t222: "cuda:0 bf16[1, 32, 512, 128]" - # t222 = prims.reshape(t219, (1, 32, 512, 128)) # t222: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t223 = ltorch.reshape(t220, 1, -1, 512, 128) # t223: "cuda:0 bf16[1, 32, 512, 128]" - # t223 = prims.reshape(t220, (1, 32, 512, 128)) # t223: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t224 = ltorch.getitem(t221, (..., slice(None, 128, None))) # t224: "cuda:0 bf16[1, 32, 512, 128]" - # t224 = prims.slice_prim(t221, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t224: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t225 = ltorch.getitem(t224, (..., slice(None, 64, None))) # t225: "cuda:0 bf16[1, 32, 512, 64]" - # t225 = prims.slice_prim(t224, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t225: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t226 = ltorch.getitem(t224, (..., slice(64, None, None))) # t226: "cuda:0 bf16[1, 32, 512, 64]" - # t226 = prims.slice_prim(t224, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t226: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t229 = ltorch.neg(t226) # t229: "cuda:0 bf16[1, 32, 512, 64]" - # t227 = prims.convert_element_type(t226, dtypes.float32) # t227: "cuda:0 f32[1, 32, 512, 64]" - # t228 = prims.neg(t227) # t228: "cuda:0 f32[1, 32, 512, 64]" - # t229 = prims.convert_element_type(t228, dtypes.bfloat16) # t229: "cuda:0 bf16[1, 32, 512, 64]" - t230 = ltorch.cat((t229, t225), -1) # t230: "cuda:0 bf16[1, 32, 512, 128]" - # t230 = prims.cat((t229, t225), -1) # t230: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t233 = ltorch.mul(t224, cos) # t233: "cuda:0 f32[1, 32, 512, 128]" - # t231 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t231: "cuda:0 f32[1, 32, 512, 128]" - # t232 = prims.convert_element_type(t224, dtypes.float32) # t232: "cuda:0 f32[1, 32, 512, 128]" - # t233 = prims.mul(t232, t231) # t233: "cuda:0 f32[1, 32, 512, 128]" - t236 = ltorch.mul(t230, sin) # t236: "cuda:0 f32[1, 32, 512, 128]" - # t234 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t234: "cuda:0 f32[1, 32, 512, 128]" - # t235 = prims.convert_element_type(t230, dtypes.float32) # t235: "cuda:0 f32[1, 32, 512, 128]" - # t236 = prims.mul(t235, t234) # t236: "cuda:0 f32[1, 32, 512, 128]" - t237 = ltorch.add(t233, t236, alpha=None) # t237: "cuda:0 f32[1, 32, 512, 128]" - # t237 = prims.add(t233, t236) # t237: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t238 = ltorch.to(t237, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t238: "cuda:0 bf16[1, 32, 512, 128]" - # t238 = prims.convert_element_type(t237, dtypes.bfloat16) # t238: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t239 = ltorch.getitem(t222, (..., slice(None, 128, None))) # t239: "cuda:0 bf16[1, 32, 512, 128]" - # t239 = prims.slice_prim(t222, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t239: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t240 = ltorch.getitem(t239, (..., slice(None, 64, None))) # t240: "cuda:0 bf16[1, 32, 512, 64]" - # t240 = prims.slice_prim(t239, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t240: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t241 = ltorch.getitem(t239, (..., slice(64, None, None))) # t241: "cuda:0 bf16[1, 32, 512, 64]" - # t241 = prims.slice_prim(t239, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t241: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t244 = ltorch.neg(t241) # t244: "cuda:0 bf16[1, 32, 512, 64]" - # t242 = prims.convert_element_type(t241, dtypes.float32) # t242: "cuda:0 f32[1, 32, 512, 64]" - # t243 = prims.neg(t242) # t243: "cuda:0 f32[1, 32, 512, 64]" - # t244 = prims.convert_element_type(t243, dtypes.bfloat16) # t244: "cuda:0 bf16[1, 32, 512, 64]" - t245 = ltorch.cat((t244, t240), -1) # t245: "cuda:0 bf16[1, 32, 512, 128]" - # t245 = prims.cat((t244, t240), -1) # t245: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t248 = ltorch.mul(t239, cos) # t248: "cuda:0 f32[1, 32, 512, 128]" - # t246 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t246: "cuda:0 f32[1, 32, 512, 128]" - # t247 = prims.convert_element_type(t239, dtypes.float32) # t247: "cuda:0 f32[1, 32, 512, 128]" - # t248 = prims.mul(t247, t246) # t248: "cuda:0 f32[1, 32, 512, 128]" - t251 = ltorch.mul(t245, sin) # t251: "cuda:0 f32[1, 32, 512, 128]" - # t249 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t249: "cuda:0 f32[1, 32, 512, 128]" - # t250 = prims.convert_element_type(t245, dtypes.float32) # t250: "cuda:0 f32[1, 32, 512, 128]" - # t251 = prims.mul(t250, t249) # t251: "cuda:0 f32[1, 32, 512, 128]" - t252 = ltorch.add(t248, t251, alpha=None) # t252: "cuda:0 f32[1, 32, 512, 128]" - # t252 = prims.add(t248, t251) # t252: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t253 = ltorch.to(t252, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t253: "cuda:0 bf16[1, 32, 512, 128]" - # t253 = prims.convert_element_type(t252, dtypes.bfloat16) # t253: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t254 = ltorch.getitem(t221, (..., slice(128, None, None))) # t254: "cuda:0 bf16[1, 32, 512, 0]" - # t254 = prims.slice_prim(t221, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t254: "cuda:0 bf16[1, 32, 512, 0]" - t255 = ltorch.cat((t238, t254), -1) # t255: "cuda:0 bf16[1, 32, 512, 128]" - # t255 = prims.cat((t238, t254), -1) # t255: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t256 = ltorch.getitem(t222, (..., slice(128, None, None))) # t256: "cuda:0 bf16[1, 32, 512, 0]" - # t256 = prims.slice_prim(t222, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t256: "cuda:0 bf16[1, 32, 512, 0]" - t257 = ltorch.cat((t253, t256), -1) # t257: "cuda:0 bf16[1, 32, 512, 128]" - # t257 = prims.cat((t253, t256), -1) # t257: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t287 = ltorch.scaled_dot_product_attention(t255, t257, t223, None, 0.0, True, scale=0.08838834764831843) # t287: "cuda:0 bf16[1, 32, 512, 128]" - # t260 = ltorch.mul(t255, 0.29730177875068026) # t260: "cuda:0 bf16[1, 32, 512, 128]" - # t258 = prims.convert_element_type(t255, dtypes.float32) # t258: "cuda:0 f32[1, 32, 512, 128]" - # t259 = prims.mul(t258, 0.29730177875068026) # t259: "cuda:0 f32[1, 32, 512, 128]" - # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 32, 512, 128]" - # t261 = ltorch.transpose(t257, -2, -1) # t261: "cuda:0 bf16[1, 32, 128, 512]" - # t261 = prims.transpose(t257, (0, 1, 3, 2)) # t261: "cuda:0 bf16[1, 32, 128, 512]" - # t264 = ltorch.mul(t261, 0.29730177875068026) # t264: "cuda:0 bf16[1, 32, 128, 512]" - # t262 = prims.convert_element_type(t261, dtypes.float32) # t262: "cuda:0 f32[1, 32, 128, 512]" - # t263 = prims.mul(t262, 0.29730177875068026) # t263: "cuda:0 f32[1, 32, 128, 512]" - # t264 = prims.convert_element_type(t263, dtypes.bfloat16) # t264: "cuda:0 bf16[1, 32, 128, 512]" - # t265 = ltorch.matmul(t260, t264) # t265: "cuda:0 bf16[1, 32, 512, 512]" - # t265 = prims.matmul(t260, t264) # t265: "cuda:0 bf16[1, 32, 512, 512]" - # t275 = ltorch.tril(t265, 0, fill_value=-float('inf')) # t275: "cuda:0 bf16[1, 32, 512, 512]" - # t266 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t266: "cuda:0 i64[512]" - # t266 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t266: "cuda:0 i64[512]" - # t267 = ltorch.unsqueeze(t266, -1) # t267: "cuda:0 i64[512, 1]" - # t267 = prims.broadcast_in_dim(t266, [512, 1], [0]) # t267: "cuda:0 i64[512, 1]" - # t268 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t268: "cuda:0 i64[512]" - # t268 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t268: "cuda:0 i64[512]" - # t269 = ltorch.unsqueeze(t268, -2) # t269: "cuda:0 i64[1, 512]" - # t269 = prims.broadcast_in_dim(t268, [1, 512], [1]) # t269: "cuda:0 i64[1, 512]" - # t270 = ltorch.add(t267, 0, alpha=None) # t270: "cuda:0 i64[512, 1]" - # t270 = prims.add(t267, 0) # t270: "cuda:0 i64[512, 1]" - # t273 = ltorch.ge(t270, t269) # t273: "cuda:0 b8[512, 512]" - # t271 = prims.broadcast_in_dim(t270, (512, 512), (0, 1)) # t271: "cuda:0 i64[512, 512]" - # t272 = prims.broadcast_in_dim(t269, (512, 512), (0, 1)) # t272: "cuda:0 i64[512, 512]" - # t273 = prims.ge(t271, t272) # t273: "cuda:0 b8[512, 512]" - # t275 = ltorch.where(t273, t265, -float('inf')) # t275: "cuda:0 bf16[1, 32, 512, 512]" - # t274 = prims.broadcast_in_dim(t273, (1, 32, 512, 512), (2, 3)) # t274: "cuda:0 b8[1, 32, 512, 512]" - # t275 = prims.where(t274, t265, -float('inf')) # t275: "cuda:0 bf16[1, 32, 512, 512]" - # t286 = ltorch._softmax(t275, -1, dtype=None) # t286: "cuda:0 bf16[1, 32, 512, 512]" - # t276 = ltorch.to(t275, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t276: "cuda:0 f32[1, 32, 512, 512]" - # t276 = prims.convert_element_type(t275, dtypes.float32) # t276: "cuda:0 f32[1, 32, 512, 512]" - # t278 = ltorch.amax(t276, -1, True) # t278: "cuda:0 f32[1, 32, 512, 1]" - # t277 = prims.amax(t276, (3,)) # t277: "cuda:0 f32[1, 32, 512]" - # t278 = prims.broadcast_in_dim(t277, [1, 32, 512, 1], [0, 1, 2]) # t278: "cuda:0 f32[1, 32, 512, 1]" - # t280 = ltorch.sub(t276, t278, alpha=None) # t280: "cuda:0 f32[1, 32, 512, 512]" - # t279 = prims.broadcast_in_dim(t278, (1, 32, 512, 512), (0, 1, 2, 3)) # t279: "cuda:0 f32[1, 32, 512, 512]" - # t280 = prims.sub(t276, t279) # t280: "cuda:0 f32[1, 32, 512, 512]" - # t281 = ltorch.exp(t280) # t281: "cuda:0 f32[1, 32, 512, 512]" - # t281 = prims.exp(t280) # t281: "cuda:0 f32[1, 32, 512, 512]" - # t283 = ltorch.sum(t281, -1, True, dtype=None) # t283: "cuda:0 f32[1, 32, 512, 1]" - # t282 = prims.sum(t281, (3,)) # t282: "cuda:0 f32[1, 32, 512]" - # t283 = prims.broadcast_in_dim(t282, [1, 32, 512, 1], [0, 1, 2]) # t283: "cuda:0 f32[1, 32, 512, 1]" - # t285 = ltorch.true_divide(t281, t283) # t285: "cuda:0 f32[1, 32, 512, 512]" - # t284 = prims.broadcast_in_dim(t283, (1, 32, 512, 512), (0, 1, 2, 3)) # t284: "cuda:0 f32[1, 32, 512, 512]" - # t285 = prims.div(t281, t284) # t285: "cuda:0 f32[1, 32, 512, 512]" - # t286 = ltorch.to(t285, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t286: "cuda:0 bf16[1, 32, 512, 512]" - # t286 = prims.convert_element_type(t285, dtypes.bfloat16) # t286: "cuda:0 bf16[1, 32, 512, 512]" - # t287 = ltorch.matmul(t286, t223) # t287: "cuda:0 bf16[1, 32, 512, 128]" - # t287 = prims.matmul(t286, t223) # t287: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t288 = ltorch.transpose(t287, 1, 2) # t288: "cuda:0 bf16[1, 512, 32, 128]" - # t288 = prims.transpose(t287, (0, 2, 1, 3)) # t288: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t289 = ltorch.reshape(t288, 1, 512, 4096) # t289: "cuda:0 bf16[1, 512, 4096]" - # t289 = prims.reshape(t288, (1, 512, 4096)) # t289: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t293 = ltorch.linear(t289, t_transformer_h_1_attn_proj_weight, None) # t293: "cuda:0 bf16[1, 512, 4096]" - # t293 = prims.linear(t289, t_transformer_h_1_attn_proj_weight, None) # t293: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t297 = ltorch.add(t293, t187, alpha=None) # t297: "cuda:0 bf16[1, 512, 4096]" - # t294 = prims.convert_element_type(t293, dtypes.float32) # t294: "cuda:0 f32[1, 512, 4096]" - # t295 = prims.convert_element_type(t187, dtypes.float32) # t295: "cuda:0 f32[1, 512, 4096]" - # t296 = prims.add(t294, t295) # t296: "cuda:0 f32[1, 512, 4096]" - # t297 = prims.convert_element_type(t296, dtypes.bfloat16) # t297: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t298 = prims.convert_element_type(t297, dtypes.float32) # t298: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t299 = ltorch.mul(t298, t298) # t299: "cuda:0 f32[1, 512, 4096]" - # t299 = prims.mul(t298, t298) # t299: "cuda:0 f32[1, 512, 4096]" - t303 = ltorch.mean(t299, -1, True, dtype=None) # t303: "cuda:0 f32[1, 512, 1]" - # t301 = prims.sum(t299, (2,)) # t301: "cuda:0 f32[1, 512]" - # t302 = prims.broadcast_in_dim(t301, [1, 512, 1], [0, 1]) # t302: "cuda:0 f32[1, 512, 1]" - # t303 = ltorch.true_divide(t302, 4096) # t303: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t303 = prims.div(t302, 4096.0) # t303: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t305 = ltorch.add(t303, 1e-05, alpha=None) # t305: "cuda:0 f32[1, 512, 1]" - # t305 = prims.add(t303, 1e-05) # t305: "cuda:0 f32[1, 512, 1]" - t306 = ltorch.rsqrt(t305) # t306: "cuda:0 f32[1, 512, 1]" - # t306 = prims.rsqrt(t305) # t306: "cuda:0 f32[1, 512, 1]" - t308 = ltorch.mul(t298, t306) # t308: "cuda:0 f32[1, 512, 4096]" - # t307 = prims.broadcast_in_dim(t306, (1, 512, 4096), (0, 1, 2)) # t307: "cuda:0 f32[1, 512, 4096]" - # t308 = prims.mul(t298, t307) # t308: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t309 = ltorch.to(t308, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t309: "cuda:0 bf16[1, 512, 4096]" - # t309 = prims.convert_element_type(t308, dtypes.bfloat16) # t309: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t319 = ltorch.mul(t309, t_transformer_h_1_norm_2_weight) # t319: "cuda:0 bf16[1, 512, 4096]" - # t315 = prims.broadcast_in_dim(t_transformer_h_1_norm_2_weight, (1, 512, 4096), (2,)) # t315: "cuda:0 bf16[1, 512, 4096]" - # t316 = prims.convert_element_type(t309, dtypes.float32) # t316: "cuda:0 f32[1, 512, 4096]" - # t317 = prims.convert_element_type(t315, dtypes.float32) # t317: "cuda:0 f32[1, 512, 4096]" - # t318 = prims.mul(t316, t317) # t318: "cuda:0 f32[1, 512, 4096]" - # t319 = prims.convert_element_type(t318, dtypes.bfloat16) # t319: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t324 = ltorch.linear(t319, t_transformer_h_1_mlp_fc_1_weight, None) # t324: "cuda:0 bf16[1, 512, 11008]" - # t324 = prims.linear(t319, t_transformer_h_1_mlp_fc_1_weight, None) # t324: "cuda:0 bf16[1, 512, 11008]" - t328 = ltorch.linear(t319, t_transformer_h_1_mlp_fc_2_weight, None) # t328: "cuda:0 bf16[1, 512, 11008]" - # t328 = prims.linear(t319, t_transformer_h_1_mlp_fc_2_weight, None) # t328: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t338 = ltorch.silu(t324, False) # t338: "cuda:0 bf16[1, 512, 11008]" - # t329 = prims.convert_element_type(t324, dtypes.float32) # t329: "cuda:0 f32[1, 512, 11008]" - # t330 = prims.neg(t329) # t330: "cuda:0 f32[1, 512, 11008]" - # t331 = prims.exp(t330) # t331: "cuda:0 f32[1, 512, 11008]" - # t332 = prims.add(1.0, t331) # t332: "cuda:0 f32[1, 512, 11008]" - # t333 = prims.reciprocal(t332) # t333: "cuda:0 f32[1, 512, 11008]" - # t334 = prims.convert_element_type(t333, dtypes.bfloat16) # t334: "cuda:0 bf16[1, 512, 11008]" - # t335 = prims.convert_element_type(t324, dtypes.float32) # t335: "cuda:0 f32[1, 512, 11008]" - # t336 = prims.convert_element_type(t334, dtypes.float32) # t336: "cuda:0 f32[1, 512, 11008]" - # t337 = prims.mul(t335, t336) # t337: "cuda:0 f32[1, 512, 11008]" - # t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: "cuda:0 bf16[1, 512, 11008]" - t342 = ltorch.mul(t338, t328) # t342: "cuda:0 bf16[1, 512, 11008]" - # t339 = prims.convert_element_type(t338, dtypes.float32) # t339: "cuda:0 f32[1, 512, 11008]" - # t340 = prims.convert_element_type(t328, dtypes.float32) # t340: "cuda:0 f32[1, 512, 11008]" - # t341 = prims.mul(t339, t340) # t341: "cuda:0 f32[1, 512, 11008]" - # t342 = prims.convert_element_type(t341, dtypes.bfloat16) # t342: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t346 = ltorch.linear(t342, t_transformer_h_1_mlp_proj_weight, None) # t346: "cuda:0 bf16[1, 512, 4096]" - # t346 = prims.linear(t342, t_transformer_h_1_mlp_proj_weight, None) # t346: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t350 = ltorch.add(t346, t297, alpha=None) # t350: "cuda:0 bf16[1, 512, 4096]" - # t347 = prims.convert_element_type(t346, dtypes.float32) # t347: "cuda:0 f32[1, 512, 4096]" - # t348 = prims.convert_element_type(t297, dtypes.float32) # t348: "cuda:0 f32[1, 512, 4096]" - # t349 = prims.add(t347, t348) # t349: "cuda:0 f32[1, 512, 4096]" - # t350 = prims.convert_element_type(t349, dtypes.bfloat16) # t350: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t352 = prims.convert_element_type(t350, dtypes.float32) # t352: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t353 = ltorch.mul(t352, t352) # t353: "cuda:0 f32[1, 512, 4096]" - # t353 = prims.mul(t352, t352) # t353: "cuda:0 f32[1, 512, 4096]" - t357 = ltorch.mean(t353, -1, True, dtype=None) # t357: "cuda:0 f32[1, 512, 1]" - # t355 = prims.sum(t353, (2,)) # t355: "cuda:0 f32[1, 512]" - # t356 = prims.broadcast_in_dim(t355, [1, 512, 1], [0, 1]) # t356: "cuda:0 f32[1, 512, 1]" - # t357 = ltorch.true_divide(t356, 4096) # t357: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t357 = prims.div(t356, 4096.0) # t357: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t359 = ltorch.add(t357, 1e-05, alpha=None) # t359: "cuda:0 f32[1, 512, 1]" - # t359 = prims.add(t357, 1e-05) # t359: "cuda:0 f32[1, 512, 1]" - t360 = ltorch.rsqrt(t359) # t360: "cuda:0 f32[1, 512, 1]" - # t360 = prims.rsqrt(t359) # t360: "cuda:0 f32[1, 512, 1]" - t362 = ltorch.mul(t352, t360) # t362: "cuda:0 f32[1, 512, 4096]" - # t361 = prims.broadcast_in_dim(t360, (1, 512, 4096), (0, 1, 2)) # t361: "cuda:0 f32[1, 512, 4096]" - # t362 = prims.mul(t352, t361) # t362: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t363 = ltorch.to(t362, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t363: "cuda:0 bf16[1, 512, 4096]" - # t363 = prims.convert_element_type(t362, dtypes.bfloat16) # t363: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t373 = ltorch.mul(t363, t_transformer_h_2_norm_1_weight) # t373: "cuda:0 bf16[1, 512, 4096]" - # t369 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_weight, (1, 512, 4096), (2,)) # t369: "cuda:0 bf16[1, 512, 4096]" - # t370 = prims.convert_element_type(t363, dtypes.float32) # t370: "cuda:0 f32[1, 512, 4096]" - # t371 = prims.convert_element_type(t369, dtypes.float32) # t371: "cuda:0 f32[1, 512, 4096]" - # t372 = prims.mul(t370, t371) # t372: "cuda:0 f32[1, 512, 4096]" - # t373 = prims.convert_element_type(t372, dtypes.bfloat16) # t373: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t378 = ltorch.linear(t373, t_transformer_h_2_attn_attn_weight, None) # t378: "cuda:0 bf16[1, 512, 12288]" - # t378 = prims.linear(t373, t_transformer_h_2_attn_attn_weight, None) # t378: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t379 = ltorch.view(t378, 1, 512, 32, 3, 128) # t379: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t379 = ltorch.reshape(t378, (1, 512, 32, 3, 128)) # t379: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t379 = prims.reshape(t378, (1, 512, 32, 3, 128)) # t379: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t380 = ltorch.permute(t379, 0, 2, 3, 1, 4) # t380: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t380 = prims.transpose(t379, (0, 2, 3, 1, 4)) # t380: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t381, t382, t383) = ltorch.split(t380, (1, 1, 1), 2) - # t381 = prims.slice_prim(t380, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t381: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t382 = prims.slice_prim(t380, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t382: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t383 = prims.slice_prim(t380, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t383: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t384 = ltorch.reshape(t381, 1, -1, 512, 128) # t384: "cuda:0 bf16[1, 32, 512, 128]" - # t384 = prims.reshape(t381, (1, 32, 512, 128)) # t384: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t385 = ltorch.reshape(t382, 1, -1, 512, 128) # t385: "cuda:0 bf16[1, 32, 512, 128]" - # t385 = prims.reshape(t382, (1, 32, 512, 128)) # t385: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t386 = ltorch.reshape(t383, 1, -1, 512, 128) # t386: "cuda:0 bf16[1, 32, 512, 128]" - # t386 = prims.reshape(t383, (1, 32, 512, 128)) # t386: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t387 = ltorch.getitem(t384, (..., slice(None, 128, None))) # t387: "cuda:0 bf16[1, 32, 512, 128]" - # t387 = prims.slice_prim(t384, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t387: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t388 = ltorch.getitem(t387, (..., slice(None, 64, None))) # t388: "cuda:0 bf16[1, 32, 512, 64]" - # t388 = prims.slice_prim(t387, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t388: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t389 = ltorch.getitem(t387, (..., slice(64, None, None))) # t389: "cuda:0 bf16[1, 32, 512, 64]" - # t389 = prims.slice_prim(t387, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t389: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t392 = ltorch.neg(t389) # t392: "cuda:0 bf16[1, 32, 512, 64]" - # t390 = prims.convert_element_type(t389, dtypes.float32) # t390: "cuda:0 f32[1, 32, 512, 64]" - # t391 = prims.neg(t390) # t391: "cuda:0 f32[1, 32, 512, 64]" - # t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: "cuda:0 bf16[1, 32, 512, 64]" - t393 = ltorch.cat((t392, t388), -1) # t393: "cuda:0 bf16[1, 32, 512, 128]" - # t393 = prims.cat((t392, t388), -1) # t393: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t396 = ltorch.mul(t387, cos) # t396: "cuda:0 f32[1, 32, 512, 128]" - # t394 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t394: "cuda:0 f32[1, 32, 512, 128]" - # t395 = prims.convert_element_type(t387, dtypes.float32) # t395: "cuda:0 f32[1, 32, 512, 128]" - # t396 = prims.mul(t395, t394) # t396: "cuda:0 f32[1, 32, 512, 128]" - t399 = ltorch.mul(t393, sin) # t399: "cuda:0 f32[1, 32, 512, 128]" - # t397 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t397: "cuda:0 f32[1, 32, 512, 128]" - # t398 = prims.convert_element_type(t393, dtypes.float32) # t398: "cuda:0 f32[1, 32, 512, 128]" - # t399 = prims.mul(t398, t397) # t399: "cuda:0 f32[1, 32, 512, 128]" - t400 = ltorch.add(t396, t399, alpha=None) # t400: "cuda:0 f32[1, 32, 512, 128]" - # t400 = prims.add(t396, t399) # t400: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t401 = ltorch.to(t400, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t401: "cuda:0 bf16[1, 32, 512, 128]" - # t401 = prims.convert_element_type(t400, dtypes.bfloat16) # t401: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t402 = ltorch.getitem(t385, (..., slice(None, 128, None))) # t402: "cuda:0 bf16[1, 32, 512, 128]" - # t402 = prims.slice_prim(t385, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t402: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t403 = ltorch.getitem(t402, (..., slice(None, 64, None))) # t403: "cuda:0 bf16[1, 32, 512, 64]" - # t403 = prims.slice_prim(t402, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t403: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t404 = ltorch.getitem(t402, (..., slice(64, None, None))) # t404: "cuda:0 bf16[1, 32, 512, 64]" - # t404 = prims.slice_prim(t402, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t404: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t407 = ltorch.neg(t404) # t407: "cuda:0 bf16[1, 32, 512, 64]" - # t405 = prims.convert_element_type(t404, dtypes.float32) # t405: "cuda:0 f32[1, 32, 512, 64]" - # t406 = prims.neg(t405) # t406: "cuda:0 f32[1, 32, 512, 64]" - # t407 = prims.convert_element_type(t406, dtypes.bfloat16) # t407: "cuda:0 bf16[1, 32, 512, 64]" - t408 = ltorch.cat((t407, t403), -1) # t408: "cuda:0 bf16[1, 32, 512, 128]" - # t408 = prims.cat((t407, t403), -1) # t408: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t411 = ltorch.mul(t402, cos) # t411: "cuda:0 f32[1, 32, 512, 128]" - # t409 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t409: "cuda:0 f32[1, 32, 512, 128]" - # t410 = prims.convert_element_type(t402, dtypes.float32) # t410: "cuda:0 f32[1, 32, 512, 128]" - # t411 = prims.mul(t410, t409) # t411: "cuda:0 f32[1, 32, 512, 128]" - t414 = ltorch.mul(t408, sin) # t414: "cuda:0 f32[1, 32, 512, 128]" - # t412 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t412: "cuda:0 f32[1, 32, 512, 128]" - # t413 = prims.convert_element_type(t408, dtypes.float32) # t413: "cuda:0 f32[1, 32, 512, 128]" - # t414 = prims.mul(t413, t412) # t414: "cuda:0 f32[1, 32, 512, 128]" - t415 = ltorch.add(t411, t414, alpha=None) # t415: "cuda:0 f32[1, 32, 512, 128]" - # t415 = prims.add(t411, t414) # t415: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t416 = ltorch.to(t415, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t416: "cuda:0 bf16[1, 32, 512, 128]" - # t416 = prims.convert_element_type(t415, dtypes.bfloat16) # t416: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t417 = ltorch.getitem(t384, (..., slice(128, None, None))) # t417: "cuda:0 bf16[1, 32, 512, 0]" - # t417 = prims.slice_prim(t384, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t417: "cuda:0 bf16[1, 32, 512, 0]" - t418 = ltorch.cat((t401, t417), -1) # t418: "cuda:0 bf16[1, 32, 512, 128]" - # t418 = prims.cat((t401, t417), -1) # t418: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t419 = ltorch.getitem(t385, (..., slice(128, None, None))) # t419: "cuda:0 bf16[1, 32, 512, 0]" - # t419 = prims.slice_prim(t385, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t419: "cuda:0 bf16[1, 32, 512, 0]" - t420 = ltorch.cat((t416, t419), -1) # t420: "cuda:0 bf16[1, 32, 512, 128]" - # t420 = prims.cat((t416, t419), -1) # t420: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t450 = ltorch.scaled_dot_product_attention(t418, t420, t386, None, 0.0, True, scale=0.08838834764831843) # t450: "cuda:0 bf16[1, 32, 512, 128]" - # t423 = ltorch.mul(t418, 0.29730177875068026) # t423: "cuda:0 bf16[1, 32, 512, 128]" - # t421 = prims.convert_element_type(t418, dtypes.float32) # t421: "cuda:0 f32[1, 32, 512, 128]" - # t422 = prims.mul(t421, 0.29730177875068026) # t422: "cuda:0 f32[1, 32, 512, 128]" - # t423 = prims.convert_element_type(t422, dtypes.bfloat16) # t423: "cuda:0 bf16[1, 32, 512, 128]" - # t424 = ltorch.transpose(t420, -2, -1) # t424: "cuda:0 bf16[1, 32, 128, 512]" - # t424 = prims.transpose(t420, (0, 1, 3, 2)) # t424: "cuda:0 bf16[1, 32, 128, 512]" - # t427 = ltorch.mul(t424, 0.29730177875068026) # t427: "cuda:0 bf16[1, 32, 128, 512]" - # t425 = prims.convert_element_type(t424, dtypes.float32) # t425: "cuda:0 f32[1, 32, 128, 512]" - # t426 = prims.mul(t425, 0.29730177875068026) # t426: "cuda:0 f32[1, 32, 128, 512]" - # t427 = prims.convert_element_type(t426, dtypes.bfloat16) # t427: "cuda:0 bf16[1, 32, 128, 512]" - # t428 = ltorch.matmul(t423, t427) # t428: "cuda:0 bf16[1, 32, 512, 512]" - # t428 = prims.matmul(t423, t427) # t428: "cuda:0 bf16[1, 32, 512, 512]" - # t438 = ltorch.tril(t428, 0, fill_value=-float('inf')) # t438: "cuda:0 bf16[1, 32, 512, 512]" - # t429 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t429: "cuda:0 i64[512]" - # t429 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t429: "cuda:0 i64[512]" - # t430 = ltorch.unsqueeze(t429, -1) # t430: "cuda:0 i64[512, 1]" - # t430 = prims.broadcast_in_dim(t429, [512, 1], [0]) # t430: "cuda:0 i64[512, 1]" - # t431 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t431: "cuda:0 i64[512]" - # t431 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t431: "cuda:0 i64[512]" - # t432 = ltorch.unsqueeze(t431, -2) # t432: "cuda:0 i64[1, 512]" - # t432 = prims.broadcast_in_dim(t431, [1, 512], [1]) # t432: "cuda:0 i64[1, 512]" - # t433 = ltorch.add(t430, 0, alpha=None) # t433: "cuda:0 i64[512, 1]" - # t433 = prims.add(t430, 0) # t433: "cuda:0 i64[512, 1]" - # t436 = ltorch.ge(t433, t432) # t436: "cuda:0 b8[512, 512]" - # t434 = prims.broadcast_in_dim(t433, (512, 512), (0, 1)) # t434: "cuda:0 i64[512, 512]" - # t435 = prims.broadcast_in_dim(t432, (512, 512), (0, 1)) # t435: "cuda:0 i64[512, 512]" - # t436 = prims.ge(t434, t435) # t436: "cuda:0 b8[512, 512]" - # t438 = ltorch.where(t436, t428, -float('inf')) # t438: "cuda:0 bf16[1, 32, 512, 512]" - # t437 = prims.broadcast_in_dim(t436, (1, 32, 512, 512), (2, 3)) # t437: "cuda:0 b8[1, 32, 512, 512]" - # t438 = prims.where(t437, t428, -float('inf')) # t438: "cuda:0 bf16[1, 32, 512, 512]" - # t449 = ltorch._softmax(t438, -1, dtype=None) # t449: "cuda:0 bf16[1, 32, 512, 512]" - # t439 = ltorch.to(t438, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t439: "cuda:0 f32[1, 32, 512, 512]" - # t439 = prims.convert_element_type(t438, dtypes.float32) # t439: "cuda:0 f32[1, 32, 512, 512]" - # t441 = ltorch.amax(t439, -1, True) # t441: "cuda:0 f32[1, 32, 512, 1]" - # t440 = prims.amax(t439, (3,)) # t440: "cuda:0 f32[1, 32, 512]" - # t441 = prims.broadcast_in_dim(t440, [1, 32, 512, 1], [0, 1, 2]) # t441: "cuda:0 f32[1, 32, 512, 1]" - # t443 = ltorch.sub(t439, t441, alpha=None) # t443: "cuda:0 f32[1, 32, 512, 512]" - # t442 = prims.broadcast_in_dim(t441, (1, 32, 512, 512), (0, 1, 2, 3)) # t442: "cuda:0 f32[1, 32, 512, 512]" - # t443 = prims.sub(t439, t442) # t443: "cuda:0 f32[1, 32, 512, 512]" - # t444 = ltorch.exp(t443) # t444: "cuda:0 f32[1, 32, 512, 512]" - # t444 = prims.exp(t443) # t444: "cuda:0 f32[1, 32, 512, 512]" - # t446 = ltorch.sum(t444, -1, True, dtype=None) # t446: "cuda:0 f32[1, 32, 512, 1]" - # t445 = prims.sum(t444, (3,)) # t445: "cuda:0 f32[1, 32, 512]" - # t446 = prims.broadcast_in_dim(t445, [1, 32, 512, 1], [0, 1, 2]) # t446: "cuda:0 f32[1, 32, 512, 1]" - # t448 = ltorch.true_divide(t444, t446) # t448: "cuda:0 f32[1, 32, 512, 512]" - # t447 = prims.broadcast_in_dim(t446, (1, 32, 512, 512), (0, 1, 2, 3)) # t447: "cuda:0 f32[1, 32, 512, 512]" - # t448 = prims.div(t444, t447) # t448: "cuda:0 f32[1, 32, 512, 512]" - # t449 = ltorch.to(t448, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t449: "cuda:0 bf16[1, 32, 512, 512]" - # t449 = prims.convert_element_type(t448, dtypes.bfloat16) # t449: "cuda:0 bf16[1, 32, 512, 512]" - # t450 = ltorch.matmul(t449, t386) # t450: "cuda:0 bf16[1, 32, 512, 128]" - # t450 = prims.matmul(t449, t386) # t450: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t451 = ltorch.transpose(t450, 1, 2) # t451: "cuda:0 bf16[1, 512, 32, 128]" - # t451 = prims.transpose(t450, (0, 2, 1, 3)) # t451: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t452 = ltorch.reshape(t451, 1, 512, 4096) # t452: "cuda:0 bf16[1, 512, 4096]" - # t452 = prims.reshape(t451, (1, 512, 4096)) # t452: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t456 = ltorch.linear(t452, t_transformer_h_2_attn_proj_weight, None) # t456: "cuda:0 bf16[1, 512, 4096]" - # t456 = prims.linear(t452, t_transformer_h_2_attn_proj_weight, None) # t456: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t460 = ltorch.add(t456, t350, alpha=None) # t460: "cuda:0 bf16[1, 512, 4096]" - # t457 = prims.convert_element_type(t456, dtypes.float32) # t457: "cuda:0 f32[1, 512, 4096]" - # t458 = prims.convert_element_type(t350, dtypes.float32) # t458: "cuda:0 f32[1, 512, 4096]" - # t459 = prims.add(t457, t458) # t459: "cuda:0 f32[1, 512, 4096]" - # t460 = prims.convert_element_type(t459, dtypes.bfloat16) # t460: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t461 = prims.convert_element_type(t460, dtypes.float32) # t461: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t462 = ltorch.mul(t461, t461) # t462: "cuda:0 f32[1, 512, 4096]" - # t462 = prims.mul(t461, t461) # t462: "cuda:0 f32[1, 512, 4096]" - t466 = ltorch.mean(t462, -1, True, dtype=None) # t466: "cuda:0 f32[1, 512, 1]" - # t464 = prims.sum(t462, (2,)) # t464: "cuda:0 f32[1, 512]" - # t465 = prims.broadcast_in_dim(t464, [1, 512, 1], [0, 1]) # t465: "cuda:0 f32[1, 512, 1]" - # t466 = ltorch.true_divide(t465, 4096) # t466: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t466 = prims.div(t465, 4096.0) # t466: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t468 = ltorch.add(t466, 1e-05, alpha=None) # t468: "cuda:0 f32[1, 512, 1]" - # t468 = prims.add(t466, 1e-05) # t468: "cuda:0 f32[1, 512, 1]" - t469 = ltorch.rsqrt(t468) # t469: "cuda:0 f32[1, 512, 1]" - # t469 = prims.rsqrt(t468) # t469: "cuda:0 f32[1, 512, 1]" - t471 = ltorch.mul(t461, t469) # t471: "cuda:0 f32[1, 512, 4096]" - # t470 = prims.broadcast_in_dim(t469, (1, 512, 4096), (0, 1, 2)) # t470: "cuda:0 f32[1, 512, 4096]" - # t471 = prims.mul(t461, t470) # t471: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t472 = ltorch.to(t471, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t472: "cuda:0 bf16[1, 512, 4096]" - # t472 = prims.convert_element_type(t471, dtypes.bfloat16) # t472: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t482 = ltorch.mul(t472, t_transformer_h_2_norm_2_weight) # t482: "cuda:0 bf16[1, 512, 4096]" - # t478 = prims.broadcast_in_dim(t_transformer_h_2_norm_2_weight, (1, 512, 4096), (2,)) # t478: "cuda:0 bf16[1, 512, 4096]" - # t479 = prims.convert_element_type(t472, dtypes.float32) # t479: "cuda:0 f32[1, 512, 4096]" - # t480 = prims.convert_element_type(t478, dtypes.float32) # t480: "cuda:0 f32[1, 512, 4096]" - # t481 = prims.mul(t479, t480) # t481: "cuda:0 f32[1, 512, 4096]" - # t482 = prims.convert_element_type(t481, dtypes.bfloat16) # t482: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t487 = ltorch.linear(t482, t_transformer_h_2_mlp_fc_1_weight, None) # t487: "cuda:0 bf16[1, 512, 11008]" - # t487 = prims.linear(t482, t_transformer_h_2_mlp_fc_1_weight, None) # t487: "cuda:0 bf16[1, 512, 11008]" - t491 = ltorch.linear(t482, t_transformer_h_2_mlp_fc_2_weight, None) # t491: "cuda:0 bf16[1, 512, 11008]" - # t491 = prims.linear(t482, t_transformer_h_2_mlp_fc_2_weight, None) # t491: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t501 = ltorch.silu(t487, False) # t501: "cuda:0 bf16[1, 512, 11008]" - # t492 = prims.convert_element_type(t487, dtypes.float32) # t492: "cuda:0 f32[1, 512, 11008]" - # t493 = prims.neg(t492) # t493: "cuda:0 f32[1, 512, 11008]" - # t494 = prims.exp(t493) # t494: "cuda:0 f32[1, 512, 11008]" - # t495 = prims.add(1.0, t494) # t495: "cuda:0 f32[1, 512, 11008]" - # t496 = prims.reciprocal(t495) # t496: "cuda:0 f32[1, 512, 11008]" - # t497 = prims.convert_element_type(t496, dtypes.bfloat16) # t497: "cuda:0 bf16[1, 512, 11008]" - # t498 = prims.convert_element_type(t487, dtypes.float32) # t498: "cuda:0 f32[1, 512, 11008]" - # t499 = prims.convert_element_type(t497, dtypes.float32) # t499: "cuda:0 f32[1, 512, 11008]" - # t500 = prims.mul(t498, t499) # t500: "cuda:0 f32[1, 512, 11008]" - # t501 = prims.convert_element_type(t500, dtypes.bfloat16) # t501: "cuda:0 bf16[1, 512, 11008]" - t505 = ltorch.mul(t501, t491) # t505: "cuda:0 bf16[1, 512, 11008]" - # t502 = prims.convert_element_type(t501, dtypes.float32) # t502: "cuda:0 f32[1, 512, 11008]" - # t503 = prims.convert_element_type(t491, dtypes.float32) # t503: "cuda:0 f32[1, 512, 11008]" - # t504 = prims.mul(t502, t503) # t504: "cuda:0 f32[1, 512, 11008]" - # t505 = prims.convert_element_type(t504, dtypes.bfloat16) # t505: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t509 = ltorch.linear(t505, t_transformer_h_2_mlp_proj_weight, None) # t509: "cuda:0 bf16[1, 512, 4096]" - # t509 = prims.linear(t505, t_transformer_h_2_mlp_proj_weight, None) # t509: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t513 = ltorch.add(t509, t460, alpha=None) # t513: "cuda:0 bf16[1, 512, 4096]" - # t510 = prims.convert_element_type(t509, dtypes.float32) # t510: "cuda:0 f32[1, 512, 4096]" - # t511 = prims.convert_element_type(t460, dtypes.float32) # t511: "cuda:0 f32[1, 512, 4096]" - # t512 = prims.add(t510, t511) # t512: "cuda:0 f32[1, 512, 4096]" - # t513 = prims.convert_element_type(t512, dtypes.bfloat16) # t513: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t515 = prims.convert_element_type(t513, dtypes.float32) # t515: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t516 = ltorch.mul(t515, t515) # t516: "cuda:0 f32[1, 512, 4096]" - # t516 = prims.mul(t515, t515) # t516: "cuda:0 f32[1, 512, 4096]" - t520 = ltorch.mean(t516, -1, True, dtype=None) # t520: "cuda:0 f32[1, 512, 1]" - # t518 = prims.sum(t516, (2,)) # t518: "cuda:0 f32[1, 512]" - # t519 = prims.broadcast_in_dim(t518, [1, 512, 1], [0, 1]) # t519: "cuda:0 f32[1, 512, 1]" - # t520 = ltorch.true_divide(t519, 4096) # t520: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t520 = prims.div(t519, 4096.0) # t520: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t522 = ltorch.add(t520, 1e-05, alpha=None) # t522: "cuda:0 f32[1, 512, 1]" - # t522 = prims.add(t520, 1e-05) # t522: "cuda:0 f32[1, 512, 1]" - t523 = ltorch.rsqrt(t522) # t523: "cuda:0 f32[1, 512, 1]" - # t523 = prims.rsqrt(t522) # t523: "cuda:0 f32[1, 512, 1]" - t525 = ltorch.mul(t515, t523) # t525: "cuda:0 f32[1, 512, 4096]" - # t524 = prims.broadcast_in_dim(t523, (1, 512, 4096), (0, 1, 2)) # t524: "cuda:0 f32[1, 512, 4096]" - # t525 = prims.mul(t515, t524) # t525: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t526 = ltorch.to(t525, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t526: "cuda:0 bf16[1, 512, 4096]" - # t526 = prims.convert_element_type(t525, dtypes.bfloat16) # t526: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t536 = ltorch.mul(t526, t_transformer_h_3_norm_1_weight) # t536: "cuda:0 bf16[1, 512, 4096]" - # t532 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_weight, (1, 512, 4096), (2,)) # t532: "cuda:0 bf16[1, 512, 4096]" - # t533 = prims.convert_element_type(t526, dtypes.float32) # t533: "cuda:0 f32[1, 512, 4096]" - # t534 = prims.convert_element_type(t532, dtypes.float32) # t534: "cuda:0 f32[1, 512, 4096]" - # t535 = prims.mul(t533, t534) # t535: "cuda:0 f32[1, 512, 4096]" - # t536 = prims.convert_element_type(t535, dtypes.bfloat16) # t536: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t541 = ltorch.linear(t536, t_transformer_h_3_attn_attn_weight, None) # t541: "cuda:0 bf16[1, 512, 12288]" - # t541 = prims.linear(t536, t_transformer_h_3_attn_attn_weight, None) # t541: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t542 = ltorch.view(t541, 1, 512, 32, 3, 128) # t542: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t542 = ltorch.reshape(t541, (1, 512, 32, 3, 128)) # t542: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t542 = prims.reshape(t541, (1, 512, 32, 3, 128)) # t542: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t543 = ltorch.permute(t542, 0, 2, 3, 1, 4) # t543: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t543 = prims.transpose(t542, (0, 2, 3, 1, 4)) # t543: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t544, t545, t546) = ltorch.split(t543, (1, 1, 1), 2) - # t544 = prims.slice_prim(t543, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t544: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t545 = prims.slice_prim(t543, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t545: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t546 = prims.slice_prim(t543, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t546: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t547 = ltorch.reshape(t544, 1, -1, 512, 128) # t547: "cuda:0 bf16[1, 32, 512, 128]" - # t547 = prims.reshape(t544, (1, 32, 512, 128)) # t547: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t548 = ltorch.reshape(t545, 1, -1, 512, 128) # t548: "cuda:0 bf16[1, 32, 512, 128]" - # t548 = prims.reshape(t545, (1, 32, 512, 128)) # t548: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t549 = ltorch.reshape(t546, 1, -1, 512, 128) # t549: "cuda:0 bf16[1, 32, 512, 128]" - # t549 = prims.reshape(t546, (1, 32, 512, 128)) # t549: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t550 = ltorch.getitem(t547, (..., slice(None, 128, None))) # t550: "cuda:0 bf16[1, 32, 512, 128]" - # t550 = prims.slice_prim(t547, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t550: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t551 = ltorch.getitem(t550, (..., slice(None, 64, None))) # t551: "cuda:0 bf16[1, 32, 512, 64]" - # t551 = prims.slice_prim(t550, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t551: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t552 = ltorch.getitem(t550, (..., slice(64, None, None))) # t552: "cuda:0 bf16[1, 32, 512, 64]" - # t552 = prims.slice_prim(t550, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t552: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t555 = ltorch.neg(t552) # t555: "cuda:0 bf16[1, 32, 512, 64]" - # t553 = prims.convert_element_type(t552, dtypes.float32) # t553: "cuda:0 f32[1, 32, 512, 64]" - # t554 = prims.neg(t553) # t554: "cuda:0 f32[1, 32, 512, 64]" - # t555 = prims.convert_element_type(t554, dtypes.bfloat16) # t555: "cuda:0 bf16[1, 32, 512, 64]" - t556 = ltorch.cat((t555, t551), -1) # t556: "cuda:0 bf16[1, 32, 512, 128]" - # t556 = prims.cat((t555, t551), -1) # t556: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t559 = ltorch.mul(t550, cos) # t559: "cuda:0 f32[1, 32, 512, 128]" - # t557 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t557: "cuda:0 f32[1, 32, 512, 128]" - # t558 = prims.convert_element_type(t550, dtypes.float32) # t558: "cuda:0 f32[1, 32, 512, 128]" - # t559 = prims.mul(t558, t557) # t559: "cuda:0 f32[1, 32, 512, 128]" - t562 = ltorch.mul(t556, sin) # t562: "cuda:0 f32[1, 32, 512, 128]" - # t560 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t560: "cuda:0 f32[1, 32, 512, 128]" - # t561 = prims.convert_element_type(t556, dtypes.float32) # t561: "cuda:0 f32[1, 32, 512, 128]" - # t562 = prims.mul(t561, t560) # t562: "cuda:0 f32[1, 32, 512, 128]" - t563 = ltorch.add(t559, t562, alpha=None) # t563: "cuda:0 f32[1, 32, 512, 128]" - # t563 = prims.add(t559, t562) # t563: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t564 = ltorch.to(t563, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t564: "cuda:0 bf16[1, 32, 512, 128]" - # t564 = prims.convert_element_type(t563, dtypes.bfloat16) # t564: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t565 = ltorch.getitem(t548, (..., slice(None, 128, None))) # t565: "cuda:0 bf16[1, 32, 512, 128]" - # t565 = prims.slice_prim(t548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t565: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t566 = ltorch.getitem(t565, (..., slice(None, 64, None))) # t566: "cuda:0 bf16[1, 32, 512, 64]" - # t566 = prims.slice_prim(t565, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t566: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t567 = ltorch.getitem(t565, (..., slice(64, None, None))) # t567: "cuda:0 bf16[1, 32, 512, 64]" - # t567 = prims.slice_prim(t565, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t567: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t570 = ltorch.neg(t567) # t570: "cuda:0 bf16[1, 32, 512, 64]" - # t568 = prims.convert_element_type(t567, dtypes.float32) # t568: "cuda:0 f32[1, 32, 512, 64]" - # t569 = prims.neg(t568) # t569: "cuda:0 f32[1, 32, 512, 64]" - # t570 = prims.convert_element_type(t569, dtypes.bfloat16) # t570: "cuda:0 bf16[1, 32, 512, 64]" - t571 = ltorch.cat((t570, t566), -1) # t571: "cuda:0 bf16[1, 32, 512, 128]" - # t571 = prims.cat((t570, t566), -1) # t571: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t574 = ltorch.mul(t565, cos) # t574: "cuda:0 f32[1, 32, 512, 128]" - # t572 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t572: "cuda:0 f32[1, 32, 512, 128]" - # t573 = prims.convert_element_type(t565, dtypes.float32) # t573: "cuda:0 f32[1, 32, 512, 128]" - # t574 = prims.mul(t573, t572) # t574: "cuda:0 f32[1, 32, 512, 128]" - t577 = ltorch.mul(t571, sin) # t577: "cuda:0 f32[1, 32, 512, 128]" - # t575 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t575: "cuda:0 f32[1, 32, 512, 128]" - # t576 = prims.convert_element_type(t571, dtypes.float32) # t576: "cuda:0 f32[1, 32, 512, 128]" - # t577 = prims.mul(t576, t575) # t577: "cuda:0 f32[1, 32, 512, 128]" - t578 = ltorch.add(t574, t577, alpha=None) # t578: "cuda:0 f32[1, 32, 512, 128]" - # t578 = prims.add(t574, t577) # t578: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t579 = ltorch.to(t578, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t579: "cuda:0 bf16[1, 32, 512, 128]" - # t579 = prims.convert_element_type(t578, dtypes.bfloat16) # t579: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t580 = ltorch.getitem(t547, (..., slice(128, None, None))) # t580: "cuda:0 bf16[1, 32, 512, 0]" - # t580 = prims.slice_prim(t547, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t580: "cuda:0 bf16[1, 32, 512, 0]" - t581 = ltorch.cat((t564, t580), -1) # t581: "cuda:0 bf16[1, 32, 512, 128]" - # t581 = prims.cat((t564, t580), -1) # t581: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t582 = ltorch.getitem(t548, (..., slice(128, None, None))) # t582: "cuda:0 bf16[1, 32, 512, 0]" - # t582 = prims.slice_prim(t548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t582: "cuda:0 bf16[1, 32, 512, 0]" - t583 = ltorch.cat((t579, t582), -1) # t583: "cuda:0 bf16[1, 32, 512, 128]" - # t583 = prims.cat((t579, t582), -1) # t583: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t613 = ltorch.scaled_dot_product_attention(t581, t583, t549, None, 0.0, True, scale=0.08838834764831843) # t613: "cuda:0 bf16[1, 32, 512, 128]" - # t586 = ltorch.mul(t581, 0.29730177875068026) # t586: "cuda:0 bf16[1, 32, 512, 128]" - # t584 = prims.convert_element_type(t581, dtypes.float32) # t584: "cuda:0 f32[1, 32, 512, 128]" - # t585 = prims.mul(t584, 0.29730177875068026) # t585: "cuda:0 f32[1, 32, 512, 128]" - # t586 = prims.convert_element_type(t585, dtypes.bfloat16) # t586: "cuda:0 bf16[1, 32, 512, 128]" - # t587 = ltorch.transpose(t583, -2, -1) # t587: "cuda:0 bf16[1, 32, 128, 512]" - # t587 = prims.transpose(t583, (0, 1, 3, 2)) # t587: "cuda:0 bf16[1, 32, 128, 512]" - # t590 = ltorch.mul(t587, 0.29730177875068026) # t590: "cuda:0 bf16[1, 32, 128, 512]" - # t588 = prims.convert_element_type(t587, dtypes.float32) # t588: "cuda:0 f32[1, 32, 128, 512]" - # t589 = prims.mul(t588, 0.29730177875068026) # t589: "cuda:0 f32[1, 32, 128, 512]" - # t590 = prims.convert_element_type(t589, dtypes.bfloat16) # t590: "cuda:0 bf16[1, 32, 128, 512]" - # t591 = ltorch.matmul(t586, t590) # t591: "cuda:0 bf16[1, 32, 512, 512]" - # t591 = prims.matmul(t586, t590) # t591: "cuda:0 bf16[1, 32, 512, 512]" - # t601 = ltorch.tril(t591, 0, fill_value=-float('inf')) # t601: "cuda:0 bf16[1, 32, 512, 512]" - # t592 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t592: "cuda:0 i64[512]" - # t592 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t592: "cuda:0 i64[512]" - # t593 = ltorch.unsqueeze(t592, -1) # t593: "cuda:0 i64[512, 1]" - # t593 = prims.broadcast_in_dim(t592, [512, 1], [0]) # t593: "cuda:0 i64[512, 1]" - # t594 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t594: "cuda:0 i64[512]" - # t594 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t594: "cuda:0 i64[512]" - # t595 = ltorch.unsqueeze(t594, -2) # t595: "cuda:0 i64[1, 512]" - # t595 = prims.broadcast_in_dim(t594, [1, 512], [1]) # t595: "cuda:0 i64[1, 512]" - # t596 = ltorch.add(t593, 0, alpha=None) # t596: "cuda:0 i64[512, 1]" - # t596 = prims.add(t593, 0) # t596: "cuda:0 i64[512, 1]" - # t599 = ltorch.ge(t596, t595) # t599: "cuda:0 b8[512, 512]" - # t597 = prims.broadcast_in_dim(t596, (512, 512), (0, 1)) # t597: "cuda:0 i64[512, 512]" - # t598 = prims.broadcast_in_dim(t595, (512, 512), (0, 1)) # t598: "cuda:0 i64[512, 512]" - # t599 = prims.ge(t597, t598) # t599: "cuda:0 b8[512, 512]" - # t601 = ltorch.where(t599, t591, -float('inf')) # t601: "cuda:0 bf16[1, 32, 512, 512]" - # t600 = prims.broadcast_in_dim(t599, (1, 32, 512, 512), (2, 3)) # t600: "cuda:0 b8[1, 32, 512, 512]" - # t601 = prims.where(t600, t591, -float('inf')) # t601: "cuda:0 bf16[1, 32, 512, 512]" - # t612 = ltorch._softmax(t601, -1, dtype=None) # t612: "cuda:0 bf16[1, 32, 512, 512]" - # t602 = ltorch.to(t601, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t602: "cuda:0 f32[1, 32, 512, 512]" - # t602 = prims.convert_element_type(t601, dtypes.float32) # t602: "cuda:0 f32[1, 32, 512, 512]" - # t604 = ltorch.amax(t602, -1, True) # t604: "cuda:0 f32[1, 32, 512, 1]" - # t603 = prims.amax(t602, (3,)) # t603: "cuda:0 f32[1, 32, 512]" - # t604 = prims.broadcast_in_dim(t603, [1, 32, 512, 1], [0, 1, 2]) # t604: "cuda:0 f32[1, 32, 512, 1]" - # t606 = ltorch.sub(t602, t604, alpha=None) # t606: "cuda:0 f32[1, 32, 512, 512]" - # t605 = prims.broadcast_in_dim(t604, (1, 32, 512, 512), (0, 1, 2, 3)) # t605: "cuda:0 f32[1, 32, 512, 512]" - # t606 = prims.sub(t602, t605) # t606: "cuda:0 f32[1, 32, 512, 512]" - # t607 = ltorch.exp(t606) # t607: "cuda:0 f32[1, 32, 512, 512]" - # t607 = prims.exp(t606) # t607: "cuda:0 f32[1, 32, 512, 512]" - # t609 = ltorch.sum(t607, -1, True, dtype=None) # t609: "cuda:0 f32[1, 32, 512, 1]" - # t608 = prims.sum(t607, (3,)) # t608: "cuda:0 f32[1, 32, 512]" - # t609 = prims.broadcast_in_dim(t608, [1, 32, 512, 1], [0, 1, 2]) # t609: "cuda:0 f32[1, 32, 512, 1]" - # t611 = ltorch.true_divide(t607, t609) # t611: "cuda:0 f32[1, 32, 512, 512]" - # t610 = prims.broadcast_in_dim(t609, (1, 32, 512, 512), (0, 1, 2, 3)) # t610: "cuda:0 f32[1, 32, 512, 512]" - # t611 = prims.div(t607, t610) # t611: "cuda:0 f32[1, 32, 512, 512]" - # t612 = ltorch.to(t611, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t612: "cuda:0 bf16[1, 32, 512, 512]" - # t612 = prims.convert_element_type(t611, dtypes.bfloat16) # t612: "cuda:0 bf16[1, 32, 512, 512]" - # t613 = ltorch.matmul(t612, t549) # t613: "cuda:0 bf16[1, 32, 512, 128]" - # t613 = prims.matmul(t612, t549) # t613: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t614 = ltorch.transpose(t613, 1, 2) # t614: "cuda:0 bf16[1, 512, 32, 128]" - # t614 = prims.transpose(t613, (0, 2, 1, 3)) # t614: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t615 = ltorch.reshape(t614, 1, 512, 4096) # t615: "cuda:0 bf16[1, 512, 4096]" - # t615 = prims.reshape(t614, (1, 512, 4096)) # t615: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t619 = ltorch.linear(t615, t_transformer_h_3_attn_proj_weight, None) # t619: "cuda:0 bf16[1, 512, 4096]" - # t619 = prims.linear(t615, t_transformer_h_3_attn_proj_weight, None) # t619: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t623 = ltorch.add(t619, t513, alpha=None) # t623: "cuda:0 bf16[1, 512, 4096]" - # t620 = prims.convert_element_type(t619, dtypes.float32) # t620: "cuda:0 f32[1, 512, 4096]" - # t621 = prims.convert_element_type(t513, dtypes.float32) # t621: "cuda:0 f32[1, 512, 4096]" - # t622 = prims.add(t620, t621) # t622: "cuda:0 f32[1, 512, 4096]" - # t623 = prims.convert_element_type(t622, dtypes.bfloat16) # t623: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t624 = prims.convert_element_type(t623, dtypes.float32) # t624: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t625 = ltorch.mul(t624, t624) # t625: "cuda:0 f32[1, 512, 4096]" - # t625 = prims.mul(t624, t624) # t625: "cuda:0 f32[1, 512, 4096]" - t629 = ltorch.mean(t625, -1, True, dtype=None) # t629: "cuda:0 f32[1, 512, 1]" - # t627 = prims.sum(t625, (2,)) # t627: "cuda:0 f32[1, 512]" - # t628 = prims.broadcast_in_dim(t627, [1, 512, 1], [0, 1]) # t628: "cuda:0 f32[1, 512, 1]" - # t629 = ltorch.true_divide(t628, 4096) # t629: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t629 = prims.div(t628, 4096.0) # t629: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t631 = ltorch.add(t629, 1e-05, alpha=None) # t631: "cuda:0 f32[1, 512, 1]" - # t631 = prims.add(t629, 1e-05) # t631: "cuda:0 f32[1, 512, 1]" - t632 = ltorch.rsqrt(t631) # t632: "cuda:0 f32[1, 512, 1]" - # t632 = prims.rsqrt(t631) # t632: "cuda:0 f32[1, 512, 1]" - t634 = ltorch.mul(t624, t632) # t634: "cuda:0 f32[1, 512, 4096]" - # t633 = prims.broadcast_in_dim(t632, (1, 512, 4096), (0, 1, 2)) # t633: "cuda:0 f32[1, 512, 4096]" - # t634 = prims.mul(t624, t633) # t634: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t635 = ltorch.to(t634, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t635: "cuda:0 bf16[1, 512, 4096]" - # t635 = prims.convert_element_type(t634, dtypes.bfloat16) # t635: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t645 = ltorch.mul(t635, t_transformer_h_3_norm_2_weight) # t645: "cuda:0 bf16[1, 512, 4096]" - # t641 = prims.broadcast_in_dim(t_transformer_h_3_norm_2_weight, (1, 512, 4096), (2,)) # t641: "cuda:0 bf16[1, 512, 4096]" - # t642 = prims.convert_element_type(t635, dtypes.float32) # t642: "cuda:0 f32[1, 512, 4096]" - # t643 = prims.convert_element_type(t641, dtypes.float32) # t643: "cuda:0 f32[1, 512, 4096]" - # t644 = prims.mul(t642, t643) # t644: "cuda:0 f32[1, 512, 4096]" - # t645 = prims.convert_element_type(t644, dtypes.bfloat16) # t645: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t650 = ltorch.linear(t645, t_transformer_h_3_mlp_fc_1_weight, None) # t650: "cuda:0 bf16[1, 512, 11008]" - # t650 = prims.linear(t645, t_transformer_h_3_mlp_fc_1_weight, None) # t650: "cuda:0 bf16[1, 512, 11008]" - t654 = ltorch.linear(t645, t_transformer_h_3_mlp_fc_2_weight, None) # t654: "cuda:0 bf16[1, 512, 11008]" - # t654 = prims.linear(t645, t_transformer_h_3_mlp_fc_2_weight, None) # t654: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t664 = ltorch.silu(t650, False) # t664: "cuda:0 bf16[1, 512, 11008]" - # t655 = prims.convert_element_type(t650, dtypes.float32) # t655: "cuda:0 f32[1, 512, 11008]" - # t656 = prims.neg(t655) # t656: "cuda:0 f32[1, 512, 11008]" - # t657 = prims.exp(t656) # t657: "cuda:0 f32[1, 512, 11008]" - # t658 = prims.add(1.0, t657) # t658: "cuda:0 f32[1, 512, 11008]" - # t659 = prims.reciprocal(t658) # t659: "cuda:0 f32[1, 512, 11008]" - # t660 = prims.convert_element_type(t659, dtypes.bfloat16) # t660: "cuda:0 bf16[1, 512, 11008]" - # t661 = prims.convert_element_type(t650, dtypes.float32) # t661: "cuda:0 f32[1, 512, 11008]" - # t662 = prims.convert_element_type(t660, dtypes.float32) # t662: "cuda:0 f32[1, 512, 11008]" - # t663 = prims.mul(t661, t662) # t663: "cuda:0 f32[1, 512, 11008]" - # t664 = prims.convert_element_type(t663, dtypes.bfloat16) # t664: "cuda:0 bf16[1, 512, 11008]" - t668 = ltorch.mul(t664, t654) # t668: "cuda:0 bf16[1, 512, 11008]" - # t665 = prims.convert_element_type(t664, dtypes.float32) # t665: "cuda:0 f32[1, 512, 11008]" - # t666 = prims.convert_element_type(t654, dtypes.float32) # t666: "cuda:0 f32[1, 512, 11008]" - # t667 = prims.mul(t665, t666) # t667: "cuda:0 f32[1, 512, 11008]" - # t668 = prims.convert_element_type(t667, dtypes.bfloat16) # t668: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t672 = ltorch.linear(t668, t_transformer_h_3_mlp_proj_weight, None) # t672: "cuda:0 bf16[1, 512, 4096]" - # t672 = prims.linear(t668, t_transformer_h_3_mlp_proj_weight, None) # t672: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t676 = ltorch.add(t672, t623, alpha=None) # t676: "cuda:0 bf16[1, 512, 4096]" - # t673 = prims.convert_element_type(t672, dtypes.float32) # t673: "cuda:0 f32[1, 512, 4096]" - # t674 = prims.convert_element_type(t623, dtypes.float32) # t674: "cuda:0 f32[1, 512, 4096]" - # t675 = prims.add(t673, t674) # t675: "cuda:0 f32[1, 512, 4096]" - # t676 = prims.convert_element_type(t675, dtypes.bfloat16) # t676: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t678 = prims.convert_element_type(t676, dtypes.float32) # t678: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t679 = ltorch.mul(t678, t678) # t679: "cuda:0 f32[1, 512, 4096]" - # t679 = prims.mul(t678, t678) # t679: "cuda:0 f32[1, 512, 4096]" - t683 = ltorch.mean(t679, -1, True, dtype=None) # t683: "cuda:0 f32[1, 512, 1]" - # t681 = prims.sum(t679, (2,)) # t681: "cuda:0 f32[1, 512]" - # t682 = prims.broadcast_in_dim(t681, [1, 512, 1], [0, 1]) # t682: "cuda:0 f32[1, 512, 1]" - # t683 = ltorch.true_divide(t682, 4096) # t683: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t683 = prims.div(t682, 4096.0) # t683: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t685 = ltorch.add(t683, 1e-05, alpha=None) # t685: "cuda:0 f32[1, 512, 1]" - # t685 = prims.add(t683, 1e-05) # t685: "cuda:0 f32[1, 512, 1]" - t686 = ltorch.rsqrt(t685) # t686: "cuda:0 f32[1, 512, 1]" - # t686 = prims.rsqrt(t685) # t686: "cuda:0 f32[1, 512, 1]" - t688 = ltorch.mul(t678, t686) # t688: "cuda:0 f32[1, 512, 4096]" - # t687 = prims.broadcast_in_dim(t686, (1, 512, 4096), (0, 1, 2)) # t687: "cuda:0 f32[1, 512, 4096]" - # t688 = prims.mul(t678, t687) # t688: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t689 = ltorch.to(t688, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t689: "cuda:0 bf16[1, 512, 4096]" - # t689 = prims.convert_element_type(t688, dtypes.bfloat16) # t689: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t699 = ltorch.mul(t689, t_transformer_h_4_norm_1_weight) # t699: "cuda:0 bf16[1, 512, 4096]" - # t695 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_weight, (1, 512, 4096), (2,)) # t695: "cuda:0 bf16[1, 512, 4096]" - # t696 = prims.convert_element_type(t689, dtypes.float32) # t696: "cuda:0 f32[1, 512, 4096]" - # t697 = prims.convert_element_type(t695, dtypes.float32) # t697: "cuda:0 f32[1, 512, 4096]" - # t698 = prims.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 4096]" - # t699 = prims.convert_element_type(t698, dtypes.bfloat16) # t699: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t704 = ltorch.linear(t699, t_transformer_h_4_attn_attn_weight, None) # t704: "cuda:0 bf16[1, 512, 12288]" - # t704 = prims.linear(t699, t_transformer_h_4_attn_attn_weight, None) # t704: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t705 = ltorch.view(t704, 1, 512, 32, 3, 128) # t705: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t705 = ltorch.reshape(t704, (1, 512, 32, 3, 128)) # t705: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t705 = prims.reshape(t704, (1, 512, 32, 3, 128)) # t705: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t706 = ltorch.permute(t705, 0, 2, 3, 1, 4) # t706: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t706 = prims.transpose(t705, (0, 2, 3, 1, 4)) # t706: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t707, t708, t709) = ltorch.split(t706, (1, 1, 1), 2) - # t707 = prims.slice_prim(t706, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t707: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t708 = prims.slice_prim(t706, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t708: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t709 = prims.slice_prim(t706, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t709: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t710 = ltorch.reshape(t707, 1, -1, 512, 128) # t710: "cuda:0 bf16[1, 32, 512, 128]" - # t710 = prims.reshape(t707, (1, 32, 512, 128)) # t710: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t711 = ltorch.reshape(t708, 1, -1, 512, 128) # t711: "cuda:0 bf16[1, 32, 512, 128]" - # t711 = prims.reshape(t708, (1, 32, 512, 128)) # t711: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t712 = ltorch.reshape(t709, 1, -1, 512, 128) # t712: "cuda:0 bf16[1, 32, 512, 128]" - # t712 = prims.reshape(t709, (1, 32, 512, 128)) # t712: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t713 = ltorch.getitem(t710, (..., slice(None, 128, None))) # t713: "cuda:0 bf16[1, 32, 512, 128]" - # t713 = prims.slice_prim(t710, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t713: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t714 = ltorch.getitem(t713, (..., slice(None, 64, None))) # t714: "cuda:0 bf16[1, 32, 512, 64]" - # t714 = prims.slice_prim(t713, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t714: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t715 = ltorch.getitem(t713, (..., slice(64, None, None))) # t715: "cuda:0 bf16[1, 32, 512, 64]" - # t715 = prims.slice_prim(t713, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t715: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t718 = ltorch.neg(t715) # t718: "cuda:0 bf16[1, 32, 512, 64]" - # t716 = prims.convert_element_type(t715, dtypes.float32) # t716: "cuda:0 f32[1, 32, 512, 64]" - # t717 = prims.neg(t716) # t717: "cuda:0 f32[1, 32, 512, 64]" - # t718 = prims.convert_element_type(t717, dtypes.bfloat16) # t718: "cuda:0 bf16[1, 32, 512, 64]" - t719 = ltorch.cat((t718, t714), -1) # t719: "cuda:0 bf16[1, 32, 512, 128]" - # t719 = prims.cat((t718, t714), -1) # t719: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t722 = ltorch.mul(t713, cos) # t722: "cuda:0 f32[1, 32, 512, 128]" - # t720 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t720: "cuda:0 f32[1, 32, 512, 128]" - # t721 = prims.convert_element_type(t713, dtypes.float32) # t721: "cuda:0 f32[1, 32, 512, 128]" - # t722 = prims.mul(t721, t720) # t722: "cuda:0 f32[1, 32, 512, 128]" - t725 = ltorch.mul(t719, sin) # t725: "cuda:0 f32[1, 32, 512, 128]" - # t723 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t723: "cuda:0 f32[1, 32, 512, 128]" - # t724 = prims.convert_element_type(t719, dtypes.float32) # t724: "cuda:0 f32[1, 32, 512, 128]" - # t725 = prims.mul(t724, t723) # t725: "cuda:0 f32[1, 32, 512, 128]" - t726 = ltorch.add(t722, t725, alpha=None) # t726: "cuda:0 f32[1, 32, 512, 128]" - # t726 = prims.add(t722, t725) # t726: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t727 = ltorch.to(t726, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t727: "cuda:0 bf16[1, 32, 512, 128]" - # t727 = prims.convert_element_type(t726, dtypes.bfloat16) # t727: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t728 = ltorch.getitem(t711, (..., slice(None, 128, None))) # t728: "cuda:0 bf16[1, 32, 512, 128]" - # t728 = prims.slice_prim(t711, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t728: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t729 = ltorch.getitem(t728, (..., slice(None, 64, None))) # t729: "cuda:0 bf16[1, 32, 512, 64]" - # t729 = prims.slice_prim(t728, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t729: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t730 = ltorch.getitem(t728, (..., slice(64, None, None))) # t730: "cuda:0 bf16[1, 32, 512, 64]" - # t730 = prims.slice_prim(t728, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t730: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t733 = ltorch.neg(t730) # t733: "cuda:0 bf16[1, 32, 512, 64]" - # t731 = prims.convert_element_type(t730, dtypes.float32) # t731: "cuda:0 f32[1, 32, 512, 64]" - # t732 = prims.neg(t731) # t732: "cuda:0 f32[1, 32, 512, 64]" - # t733 = prims.convert_element_type(t732, dtypes.bfloat16) # t733: "cuda:0 bf16[1, 32, 512, 64]" - t734 = ltorch.cat((t733, t729), -1) # t734: "cuda:0 bf16[1, 32, 512, 128]" - # t734 = prims.cat((t733, t729), -1) # t734: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t737 = ltorch.mul(t728, cos) # t737: "cuda:0 f32[1, 32, 512, 128]" - # t735 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t735: "cuda:0 f32[1, 32, 512, 128]" - # t736 = prims.convert_element_type(t728, dtypes.float32) # t736: "cuda:0 f32[1, 32, 512, 128]" - # t737 = prims.mul(t736, t735) # t737: "cuda:0 f32[1, 32, 512, 128]" - t740 = ltorch.mul(t734, sin) # t740: "cuda:0 f32[1, 32, 512, 128]" - # t738 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t738: "cuda:0 f32[1, 32, 512, 128]" - # t739 = prims.convert_element_type(t734, dtypes.float32) # t739: "cuda:0 f32[1, 32, 512, 128]" - # t740 = prims.mul(t739, t738) # t740: "cuda:0 f32[1, 32, 512, 128]" - t741 = ltorch.add(t737, t740, alpha=None) # t741: "cuda:0 f32[1, 32, 512, 128]" - # t741 = prims.add(t737, t740) # t741: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t742 = ltorch.to(t741, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t742: "cuda:0 bf16[1, 32, 512, 128]" - # t742 = prims.convert_element_type(t741, dtypes.bfloat16) # t742: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t743 = ltorch.getitem(t710, (..., slice(128, None, None))) # t743: "cuda:0 bf16[1, 32, 512, 0]" - # t743 = prims.slice_prim(t710, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t743: "cuda:0 bf16[1, 32, 512, 0]" - t744 = ltorch.cat((t727, t743), -1) # t744: "cuda:0 bf16[1, 32, 512, 128]" - # t744 = prims.cat((t727, t743), -1) # t744: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t745 = ltorch.getitem(t711, (..., slice(128, None, None))) # t745: "cuda:0 bf16[1, 32, 512, 0]" - # t745 = prims.slice_prim(t711, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t745: "cuda:0 bf16[1, 32, 512, 0]" - t746 = ltorch.cat((t742, t745), -1) # t746: "cuda:0 bf16[1, 32, 512, 128]" - # t746 = prims.cat((t742, t745), -1) # t746: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t776 = ltorch.scaled_dot_product_attention(t744, t746, t712, None, 0.0, True, scale=0.08838834764831843) # t776: "cuda:0 bf16[1, 32, 512, 128]" - # t749 = ltorch.mul(t744, 0.29730177875068026) # t749: "cuda:0 bf16[1, 32, 512, 128]" - # t747 = prims.convert_element_type(t744, dtypes.float32) # t747: "cuda:0 f32[1, 32, 512, 128]" - # t748 = prims.mul(t747, 0.29730177875068026) # t748: "cuda:0 f32[1, 32, 512, 128]" - # t749 = prims.convert_element_type(t748, dtypes.bfloat16) # t749: "cuda:0 bf16[1, 32, 512, 128]" - # t750 = ltorch.transpose(t746, -2, -1) # t750: "cuda:0 bf16[1, 32, 128, 512]" - # t750 = prims.transpose(t746, (0, 1, 3, 2)) # t750: "cuda:0 bf16[1, 32, 128, 512]" - # t753 = ltorch.mul(t750, 0.29730177875068026) # t753: "cuda:0 bf16[1, 32, 128, 512]" - # t751 = prims.convert_element_type(t750, dtypes.float32) # t751: "cuda:0 f32[1, 32, 128, 512]" - # t752 = prims.mul(t751, 0.29730177875068026) # t752: "cuda:0 f32[1, 32, 128, 512]" - # t753 = prims.convert_element_type(t752, dtypes.bfloat16) # t753: "cuda:0 bf16[1, 32, 128, 512]" - # t754 = ltorch.matmul(t749, t753) # t754: "cuda:0 bf16[1, 32, 512, 512]" - # t754 = prims.matmul(t749, t753) # t754: "cuda:0 bf16[1, 32, 512, 512]" - # t764 = ltorch.tril(t754, 0, fill_value=-float('inf')) # t764: "cuda:0 bf16[1, 32, 512, 512]" - # t755 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t755: "cuda:0 i64[512]" - # t755 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t755: "cuda:0 i64[512]" - # t756 = ltorch.unsqueeze(t755, -1) # t756: "cuda:0 i64[512, 1]" - # t756 = prims.broadcast_in_dim(t755, [512, 1], [0]) # t756: "cuda:0 i64[512, 1]" - # t757 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t757: "cuda:0 i64[512]" - # t757 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t757: "cuda:0 i64[512]" - # t758 = ltorch.unsqueeze(t757, -2) # t758: "cuda:0 i64[1, 512]" - # t758 = prims.broadcast_in_dim(t757, [1, 512], [1]) # t758: "cuda:0 i64[1, 512]" - # t759 = ltorch.add(t756, 0, alpha=None) # t759: "cuda:0 i64[512, 1]" - # t759 = prims.add(t756, 0) # t759: "cuda:0 i64[512, 1]" - # t762 = ltorch.ge(t759, t758) # t762: "cuda:0 b8[512, 512]" - # t760 = prims.broadcast_in_dim(t759, (512, 512), (0, 1)) # t760: "cuda:0 i64[512, 512]" - # t761 = prims.broadcast_in_dim(t758, (512, 512), (0, 1)) # t761: "cuda:0 i64[512, 512]" - # t762 = prims.ge(t760, t761) # t762: "cuda:0 b8[512, 512]" - # t764 = ltorch.where(t762, t754, -float('inf')) # t764: "cuda:0 bf16[1, 32, 512, 512]" - # t763 = prims.broadcast_in_dim(t762, (1, 32, 512, 512), (2, 3)) # t763: "cuda:0 b8[1, 32, 512, 512]" - # t764 = prims.where(t763, t754, -float('inf')) # t764: "cuda:0 bf16[1, 32, 512, 512]" - # t775 = ltorch._softmax(t764, -1, dtype=None) # t775: "cuda:0 bf16[1, 32, 512, 512]" - # t765 = ltorch.to(t764, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t765: "cuda:0 f32[1, 32, 512, 512]" - # t765 = prims.convert_element_type(t764, dtypes.float32) # t765: "cuda:0 f32[1, 32, 512, 512]" - # t767 = ltorch.amax(t765, -1, True) # t767: "cuda:0 f32[1, 32, 512, 1]" - # t766 = prims.amax(t765, (3,)) # t766: "cuda:0 f32[1, 32, 512]" - # t767 = prims.broadcast_in_dim(t766, [1, 32, 512, 1], [0, 1, 2]) # t767: "cuda:0 f32[1, 32, 512, 1]" - # t769 = ltorch.sub(t765, t767, alpha=None) # t769: "cuda:0 f32[1, 32, 512, 512]" - # t768 = prims.broadcast_in_dim(t767, (1, 32, 512, 512), (0, 1, 2, 3)) # t768: "cuda:0 f32[1, 32, 512, 512]" - # t769 = prims.sub(t765, t768) # t769: "cuda:0 f32[1, 32, 512, 512]" - # t770 = ltorch.exp(t769) # t770: "cuda:0 f32[1, 32, 512, 512]" - # t770 = prims.exp(t769) # t770: "cuda:0 f32[1, 32, 512, 512]" - # t772 = ltorch.sum(t770, -1, True, dtype=None) # t772: "cuda:0 f32[1, 32, 512, 1]" - # t771 = prims.sum(t770, (3,)) # t771: "cuda:0 f32[1, 32, 512]" - # t772 = prims.broadcast_in_dim(t771, [1, 32, 512, 1], [0, 1, 2]) # t772: "cuda:0 f32[1, 32, 512, 1]" - # t774 = ltorch.true_divide(t770, t772) # t774: "cuda:0 f32[1, 32, 512, 512]" - # t773 = prims.broadcast_in_dim(t772, (1, 32, 512, 512), (0, 1, 2, 3)) # t773: "cuda:0 f32[1, 32, 512, 512]" - # t774 = prims.div(t770, t773) # t774: "cuda:0 f32[1, 32, 512, 512]" - # t775 = ltorch.to(t774, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t775: "cuda:0 bf16[1, 32, 512, 512]" - # t775 = prims.convert_element_type(t774, dtypes.bfloat16) # t775: "cuda:0 bf16[1, 32, 512, 512]" - # t776 = ltorch.matmul(t775, t712) # t776: "cuda:0 bf16[1, 32, 512, 128]" - # t776 = prims.matmul(t775, t712) # t776: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t777 = ltorch.transpose(t776, 1, 2) # t777: "cuda:0 bf16[1, 512, 32, 128]" - # t777 = prims.transpose(t776, (0, 2, 1, 3)) # t777: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t778 = ltorch.reshape(t777, 1, 512, 4096) # t778: "cuda:0 bf16[1, 512, 4096]" - # t778 = prims.reshape(t777, (1, 512, 4096)) # t778: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t782 = ltorch.linear(t778, t_transformer_h_4_attn_proj_weight, None) # t782: "cuda:0 bf16[1, 512, 4096]" - # t782 = prims.linear(t778, t_transformer_h_4_attn_proj_weight, None) # t782: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t786 = ltorch.add(t782, t676, alpha=None) # t786: "cuda:0 bf16[1, 512, 4096]" - # t783 = prims.convert_element_type(t782, dtypes.float32) # t783: "cuda:0 f32[1, 512, 4096]" - # t784 = prims.convert_element_type(t676, dtypes.float32) # t784: "cuda:0 f32[1, 512, 4096]" - # t785 = prims.add(t783, t784) # t785: "cuda:0 f32[1, 512, 4096]" - # t786 = prims.convert_element_type(t785, dtypes.bfloat16) # t786: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t787 = prims.convert_element_type(t786, dtypes.float32) # t787: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t788 = ltorch.mul(t787, t787) # t788: "cuda:0 f32[1, 512, 4096]" - # t788 = prims.mul(t787, t787) # t788: "cuda:0 f32[1, 512, 4096]" - t792 = ltorch.mean(t788, -1, True, dtype=None) # t792: "cuda:0 f32[1, 512, 1]" - # t790 = prims.sum(t788, (2,)) # t790: "cuda:0 f32[1, 512]" - # t791 = prims.broadcast_in_dim(t790, [1, 512, 1], [0, 1]) # t791: "cuda:0 f32[1, 512, 1]" - # t792 = ltorch.true_divide(t791, 4096) # t792: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t792 = prims.div(t791, 4096.0) # t792: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t794 = ltorch.add(t792, 1e-05, alpha=None) # t794: "cuda:0 f32[1, 512, 1]" - # t794 = prims.add(t792, 1e-05) # t794: "cuda:0 f32[1, 512, 1]" - t795 = ltorch.rsqrt(t794) # t795: "cuda:0 f32[1, 512, 1]" - # t795 = prims.rsqrt(t794) # t795: "cuda:0 f32[1, 512, 1]" - t797 = ltorch.mul(t787, t795) # t797: "cuda:0 f32[1, 512, 4096]" - # t796 = prims.broadcast_in_dim(t795, (1, 512, 4096), (0, 1, 2)) # t796: "cuda:0 f32[1, 512, 4096]" - # t797 = prims.mul(t787, t796) # t797: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t798 = ltorch.to(t797, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t798: "cuda:0 bf16[1, 512, 4096]" - # t798 = prims.convert_element_type(t797, dtypes.bfloat16) # t798: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t808 = ltorch.mul(t798, t_transformer_h_4_norm_2_weight) # t808: "cuda:0 bf16[1, 512, 4096]" - # t804 = prims.broadcast_in_dim(t_transformer_h_4_norm_2_weight, (1, 512, 4096), (2,)) # t804: "cuda:0 bf16[1, 512, 4096]" - # t805 = prims.convert_element_type(t798, dtypes.float32) # t805: "cuda:0 f32[1, 512, 4096]" - # t806 = prims.convert_element_type(t804, dtypes.float32) # t806: "cuda:0 f32[1, 512, 4096]" - # t807 = prims.mul(t805, t806) # t807: "cuda:0 f32[1, 512, 4096]" - # t808 = prims.convert_element_type(t807, dtypes.bfloat16) # t808: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t813 = ltorch.linear(t808, t_transformer_h_4_mlp_fc_1_weight, None) # t813: "cuda:0 bf16[1, 512, 11008]" - # t813 = prims.linear(t808, t_transformer_h_4_mlp_fc_1_weight, None) # t813: "cuda:0 bf16[1, 512, 11008]" - t817 = ltorch.linear(t808, t_transformer_h_4_mlp_fc_2_weight, None) # t817: "cuda:0 bf16[1, 512, 11008]" - # t817 = prims.linear(t808, t_transformer_h_4_mlp_fc_2_weight, None) # t817: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t827 = ltorch.silu(t813, False) # t827: "cuda:0 bf16[1, 512, 11008]" - # t818 = prims.convert_element_type(t813, dtypes.float32) # t818: "cuda:0 f32[1, 512, 11008]" - # t819 = prims.neg(t818) # t819: "cuda:0 f32[1, 512, 11008]" - # t820 = prims.exp(t819) # t820: "cuda:0 f32[1, 512, 11008]" - # t821 = prims.add(1.0, t820) # t821: "cuda:0 f32[1, 512, 11008]" - # t822 = prims.reciprocal(t821) # t822: "cuda:0 f32[1, 512, 11008]" - # t823 = prims.convert_element_type(t822, dtypes.bfloat16) # t823: "cuda:0 bf16[1, 512, 11008]" - # t824 = prims.convert_element_type(t813, dtypes.float32) # t824: "cuda:0 f32[1, 512, 11008]" - # t825 = prims.convert_element_type(t823, dtypes.float32) # t825: "cuda:0 f32[1, 512, 11008]" - # t826 = prims.mul(t824, t825) # t826: "cuda:0 f32[1, 512, 11008]" - # t827 = prims.convert_element_type(t826, dtypes.bfloat16) # t827: "cuda:0 bf16[1, 512, 11008]" - t831 = ltorch.mul(t827, t817) # t831: "cuda:0 bf16[1, 512, 11008]" - # t828 = prims.convert_element_type(t827, dtypes.float32) # t828: "cuda:0 f32[1, 512, 11008]" - # t829 = prims.convert_element_type(t817, dtypes.float32) # t829: "cuda:0 f32[1, 512, 11008]" - # t830 = prims.mul(t828, t829) # t830: "cuda:0 f32[1, 512, 11008]" - # t831 = prims.convert_element_type(t830, dtypes.bfloat16) # t831: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t835 = ltorch.linear(t831, t_transformer_h_4_mlp_proj_weight, None) # t835: "cuda:0 bf16[1, 512, 4096]" - # t835 = prims.linear(t831, t_transformer_h_4_mlp_proj_weight, None) # t835: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t839 = ltorch.add(t835, t786, alpha=None) # t839: "cuda:0 bf16[1, 512, 4096]" - # t836 = prims.convert_element_type(t835, dtypes.float32) # t836: "cuda:0 f32[1, 512, 4096]" - # t837 = prims.convert_element_type(t786, dtypes.float32) # t837: "cuda:0 f32[1, 512, 4096]" - # t838 = prims.add(t836, t837) # t838: "cuda:0 f32[1, 512, 4096]" - # t839 = prims.convert_element_type(t838, dtypes.bfloat16) # t839: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t841 = prims.convert_element_type(t839, dtypes.float32) # t841: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t842 = ltorch.mul(t841, t841) # t842: "cuda:0 f32[1, 512, 4096]" - # t842 = prims.mul(t841, t841) # t842: "cuda:0 f32[1, 512, 4096]" - t846 = ltorch.mean(t842, -1, True, dtype=None) # t846: "cuda:0 f32[1, 512, 1]" - # t844 = prims.sum(t842, (2,)) # t844: "cuda:0 f32[1, 512]" - # t845 = prims.broadcast_in_dim(t844, [1, 512, 1], [0, 1]) # t845: "cuda:0 f32[1, 512, 1]" - # t846 = ltorch.true_divide(t845, 4096) # t846: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t846 = prims.div(t845, 4096.0) # t846: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t848 = ltorch.add(t846, 1e-05, alpha=None) # t848: "cuda:0 f32[1, 512, 1]" - # t848 = prims.add(t846, 1e-05) # t848: "cuda:0 f32[1, 512, 1]" - t849 = ltorch.rsqrt(t848) # t849: "cuda:0 f32[1, 512, 1]" - # t849 = prims.rsqrt(t848) # t849: "cuda:0 f32[1, 512, 1]" - t851 = ltorch.mul(t841, t849) # t851: "cuda:0 f32[1, 512, 4096]" - # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: "cuda:0 f32[1, 512, 4096]" - # t851 = prims.mul(t841, t850) # t851: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t852 = ltorch.to(t851, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t852: "cuda:0 bf16[1, 512, 4096]" - # t852 = prims.convert_element_type(t851, dtypes.bfloat16) # t852: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t862 = ltorch.mul(t852, t_transformer_h_5_norm_1_weight) # t862: "cuda:0 bf16[1, 512, 4096]" - # t858 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_weight, (1, 512, 4096), (2,)) # t858: "cuda:0 bf16[1, 512, 4096]" - # t859 = prims.convert_element_type(t852, dtypes.float32) # t859: "cuda:0 f32[1, 512, 4096]" - # t860 = prims.convert_element_type(t858, dtypes.float32) # t860: "cuda:0 f32[1, 512, 4096]" - # t861 = prims.mul(t859, t860) # t861: "cuda:0 f32[1, 512, 4096]" - # t862 = prims.convert_element_type(t861, dtypes.bfloat16) # t862: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t867 = ltorch.linear(t862, t_transformer_h_5_attn_attn_weight, None) # t867: "cuda:0 bf16[1, 512, 12288]" - # t867 = prims.linear(t862, t_transformer_h_5_attn_attn_weight, None) # t867: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t868 = ltorch.view(t867, 1, 512, 32, 3, 128) # t868: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t868 = ltorch.reshape(t867, (1, 512, 32, 3, 128)) # t868: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t868 = prims.reshape(t867, (1, 512, 32, 3, 128)) # t868: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t869 = ltorch.permute(t868, 0, 2, 3, 1, 4) # t869: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t869 = prims.transpose(t868, (0, 2, 3, 1, 4)) # t869: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t870, t871, t872) = ltorch.split(t869, (1, 1, 1), 2) - # t870 = prims.slice_prim(t869, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t870: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t871 = prims.slice_prim(t869, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t871: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t872 = prims.slice_prim(t869, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t872: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t873 = ltorch.reshape(t870, 1, -1, 512, 128) # t873: "cuda:0 bf16[1, 32, 512, 128]" - # t873 = prims.reshape(t870, (1, 32, 512, 128)) # t873: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t874 = ltorch.reshape(t871, 1, -1, 512, 128) # t874: "cuda:0 bf16[1, 32, 512, 128]" - # t874 = prims.reshape(t871, (1, 32, 512, 128)) # t874: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t875 = ltorch.reshape(t872, 1, -1, 512, 128) # t875: "cuda:0 bf16[1, 32, 512, 128]" - # t875 = prims.reshape(t872, (1, 32, 512, 128)) # t875: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t876 = ltorch.getitem(t873, (..., slice(None, 128, None))) # t876: "cuda:0 bf16[1, 32, 512, 128]" - # t876 = prims.slice_prim(t873, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t876: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t877 = ltorch.getitem(t876, (..., slice(None, 64, None))) # t877: "cuda:0 bf16[1, 32, 512, 64]" - # t877 = prims.slice_prim(t876, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t877: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t878 = ltorch.getitem(t876, (..., slice(64, None, None))) # t878: "cuda:0 bf16[1, 32, 512, 64]" - # t878 = prims.slice_prim(t876, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t878: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t881 = ltorch.neg(t878) # t881: "cuda:0 bf16[1, 32, 512, 64]" - # t879 = prims.convert_element_type(t878, dtypes.float32) # t879: "cuda:0 f32[1, 32, 512, 64]" - # t880 = prims.neg(t879) # t880: "cuda:0 f32[1, 32, 512, 64]" - # t881 = prims.convert_element_type(t880, dtypes.bfloat16) # t881: "cuda:0 bf16[1, 32, 512, 64]" - t882 = ltorch.cat((t881, t877), -1) # t882: "cuda:0 bf16[1, 32, 512, 128]" - # t882 = prims.cat((t881, t877), -1) # t882: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t885 = ltorch.mul(t876, cos) # t885: "cuda:0 f32[1, 32, 512, 128]" - # t883 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t883: "cuda:0 f32[1, 32, 512, 128]" - # t884 = prims.convert_element_type(t876, dtypes.float32) # t884: "cuda:0 f32[1, 32, 512, 128]" - # t885 = prims.mul(t884, t883) # t885: "cuda:0 f32[1, 32, 512, 128]" - t888 = ltorch.mul(t882, sin) # t888: "cuda:0 f32[1, 32, 512, 128]" - # t886 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t886: "cuda:0 f32[1, 32, 512, 128]" - # t887 = prims.convert_element_type(t882, dtypes.float32) # t887: "cuda:0 f32[1, 32, 512, 128]" - # t888 = prims.mul(t887, t886) # t888: "cuda:0 f32[1, 32, 512, 128]" - t889 = ltorch.add(t885, t888, alpha=None) # t889: "cuda:0 f32[1, 32, 512, 128]" - # t889 = prims.add(t885, t888) # t889: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t890 = ltorch.to(t889, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t890: "cuda:0 bf16[1, 32, 512, 128]" - # t890 = prims.convert_element_type(t889, dtypes.bfloat16) # t890: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t891 = ltorch.getitem(t874, (..., slice(None, 128, None))) # t891: "cuda:0 bf16[1, 32, 512, 128]" - # t891 = prims.slice_prim(t874, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t891: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t892 = ltorch.getitem(t891, (..., slice(None, 64, None))) # t892: "cuda:0 bf16[1, 32, 512, 64]" - # t892 = prims.slice_prim(t891, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t892: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t893 = ltorch.getitem(t891, (..., slice(64, None, None))) # t893: "cuda:0 bf16[1, 32, 512, 64]" - # t893 = prims.slice_prim(t891, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t893: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t896 = ltorch.neg(t893) # t896: "cuda:0 bf16[1, 32, 512, 64]" - # t894 = prims.convert_element_type(t893, dtypes.float32) # t894: "cuda:0 f32[1, 32, 512, 64]" - # t895 = prims.neg(t894) # t895: "cuda:0 f32[1, 32, 512, 64]" - # t896 = prims.convert_element_type(t895, dtypes.bfloat16) # t896: "cuda:0 bf16[1, 32, 512, 64]" - t897 = ltorch.cat((t896, t892), -1) # t897: "cuda:0 bf16[1, 32, 512, 128]" - # t897 = prims.cat((t896, t892), -1) # t897: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t900 = ltorch.mul(t891, cos) # t900: "cuda:0 f32[1, 32, 512, 128]" - # t898 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t898: "cuda:0 f32[1, 32, 512, 128]" - # t899 = prims.convert_element_type(t891, dtypes.float32) # t899: "cuda:0 f32[1, 32, 512, 128]" - # t900 = prims.mul(t899, t898) # t900: "cuda:0 f32[1, 32, 512, 128]" - t903 = ltorch.mul(t897, sin) # t903: "cuda:0 f32[1, 32, 512, 128]" - # t901 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t901: "cuda:0 f32[1, 32, 512, 128]" - # t902 = prims.convert_element_type(t897, dtypes.float32) # t902: "cuda:0 f32[1, 32, 512, 128]" - # t903 = prims.mul(t902, t901) # t903: "cuda:0 f32[1, 32, 512, 128]" - t904 = ltorch.add(t900, t903, alpha=None) # t904: "cuda:0 f32[1, 32, 512, 128]" - # t904 = prims.add(t900, t903) # t904: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t905 = ltorch.to(t904, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t905: "cuda:0 bf16[1, 32, 512, 128]" - # t905 = prims.convert_element_type(t904, dtypes.bfloat16) # t905: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t906 = ltorch.getitem(t873, (..., slice(128, None, None))) # t906: "cuda:0 bf16[1, 32, 512, 0]" - # t906 = prims.slice_prim(t873, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t906: "cuda:0 bf16[1, 32, 512, 0]" - t907 = ltorch.cat((t890, t906), -1) # t907: "cuda:0 bf16[1, 32, 512, 128]" - # t907 = prims.cat((t890, t906), -1) # t907: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t908 = ltorch.getitem(t874, (..., slice(128, None, None))) # t908: "cuda:0 bf16[1, 32, 512, 0]" - # t908 = prims.slice_prim(t874, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t908: "cuda:0 bf16[1, 32, 512, 0]" - t909 = ltorch.cat((t905, t908), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - # t909 = prims.cat((t905, t908), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t939 = ltorch.scaled_dot_product_attention(t907, t909, t875, None, 0.0, True, scale=0.08838834764831843) # t939: "cuda:0 bf16[1, 32, 512, 128]" - # t912 = ltorch.mul(t907, 0.29730177875068026) # t912: "cuda:0 bf16[1, 32, 512, 128]" - # t910 = prims.convert_element_type(t907, dtypes.float32) # t910: "cuda:0 f32[1, 32, 512, 128]" - # t911 = prims.mul(t910, 0.29730177875068026) # t911: "cuda:0 f32[1, 32, 512, 128]" - # t912 = prims.convert_element_type(t911, dtypes.bfloat16) # t912: "cuda:0 bf16[1, 32, 512, 128]" - # t913 = ltorch.transpose(t909, -2, -1) # t913: "cuda:0 bf16[1, 32, 128, 512]" - # t913 = prims.transpose(t909, (0, 1, 3, 2)) # t913: "cuda:0 bf16[1, 32, 128, 512]" - # t916 = ltorch.mul(t913, 0.29730177875068026) # t916: "cuda:0 bf16[1, 32, 128, 512]" - # t914 = prims.convert_element_type(t913, dtypes.float32) # t914: "cuda:0 f32[1, 32, 128, 512]" - # t915 = prims.mul(t914, 0.29730177875068026) # t915: "cuda:0 f32[1, 32, 128, 512]" - # t916 = prims.convert_element_type(t915, dtypes.bfloat16) # t916: "cuda:0 bf16[1, 32, 128, 512]" - # t917 = ltorch.matmul(t912, t916) # t917: "cuda:0 bf16[1, 32, 512, 512]" - # t917 = prims.matmul(t912, t916) # t917: "cuda:0 bf16[1, 32, 512, 512]" - # t927 = ltorch.tril(t917, 0, fill_value=-float('inf')) # t927: "cuda:0 bf16[1, 32, 512, 512]" - # t918 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t918: "cuda:0 i64[512]" - # t918 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t918: "cuda:0 i64[512]" - # t919 = ltorch.unsqueeze(t918, -1) # t919: "cuda:0 i64[512, 1]" - # t919 = prims.broadcast_in_dim(t918, [512, 1], [0]) # t919: "cuda:0 i64[512, 1]" - # t920 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t920: "cuda:0 i64[512]" - # t920 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t920: "cuda:0 i64[512]" - # t921 = ltorch.unsqueeze(t920, -2) # t921: "cuda:0 i64[1, 512]" - # t921 = prims.broadcast_in_dim(t920, [1, 512], [1]) # t921: "cuda:0 i64[1, 512]" - # t922 = ltorch.add(t919, 0, alpha=None) # t922: "cuda:0 i64[512, 1]" - # t922 = prims.add(t919, 0) # t922: "cuda:0 i64[512, 1]" - # t925 = ltorch.ge(t922, t921) # t925: "cuda:0 b8[512, 512]" - # t923 = prims.broadcast_in_dim(t922, (512, 512), (0, 1)) # t923: "cuda:0 i64[512, 512]" - # t924 = prims.broadcast_in_dim(t921, (512, 512), (0, 1)) # t924: "cuda:0 i64[512, 512]" - # t925 = prims.ge(t923, t924) # t925: "cuda:0 b8[512, 512]" - # t927 = ltorch.where(t925, t917, -float('inf')) # t927: "cuda:0 bf16[1, 32, 512, 512]" - # t926 = prims.broadcast_in_dim(t925, (1, 32, 512, 512), (2, 3)) # t926: "cuda:0 b8[1, 32, 512, 512]" - # t927 = prims.where(t926, t917, -float('inf')) # t927: "cuda:0 bf16[1, 32, 512, 512]" - # t938 = ltorch._softmax(t927, -1, dtype=None) # t938: "cuda:0 bf16[1, 32, 512, 512]" - # t928 = ltorch.to(t927, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t928: "cuda:0 f32[1, 32, 512, 512]" - # t928 = prims.convert_element_type(t927, dtypes.float32) # t928: "cuda:0 f32[1, 32, 512, 512]" - # t930 = ltorch.amax(t928, -1, True) # t930: "cuda:0 f32[1, 32, 512, 1]" - # t929 = prims.amax(t928, (3,)) # t929: "cuda:0 f32[1, 32, 512]" - # t930 = prims.broadcast_in_dim(t929, [1, 32, 512, 1], [0, 1, 2]) # t930: "cuda:0 f32[1, 32, 512, 1]" - # t932 = ltorch.sub(t928, t930, alpha=None) # t932: "cuda:0 f32[1, 32, 512, 512]" - # t931 = prims.broadcast_in_dim(t930, (1, 32, 512, 512), (0, 1, 2, 3)) # t931: "cuda:0 f32[1, 32, 512, 512]" - # t932 = prims.sub(t928, t931) # t932: "cuda:0 f32[1, 32, 512, 512]" - # t933 = ltorch.exp(t932) # t933: "cuda:0 f32[1, 32, 512, 512]" - # t933 = prims.exp(t932) # t933: "cuda:0 f32[1, 32, 512, 512]" - # t935 = ltorch.sum(t933, -1, True, dtype=None) # t935: "cuda:0 f32[1, 32, 512, 1]" - # t934 = prims.sum(t933, (3,)) # t934: "cuda:0 f32[1, 32, 512]" - # t935 = prims.broadcast_in_dim(t934, [1, 32, 512, 1], [0, 1, 2]) # t935: "cuda:0 f32[1, 32, 512, 1]" - # t937 = ltorch.true_divide(t933, t935) # t937: "cuda:0 f32[1, 32, 512, 512]" - # t936 = prims.broadcast_in_dim(t935, (1, 32, 512, 512), (0, 1, 2, 3)) # t936: "cuda:0 f32[1, 32, 512, 512]" - # t937 = prims.div(t933, t936) # t937: "cuda:0 f32[1, 32, 512, 512]" - # t938 = ltorch.to(t937, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t938: "cuda:0 bf16[1, 32, 512, 512]" - # t938 = prims.convert_element_type(t937, dtypes.bfloat16) # t938: "cuda:0 bf16[1, 32, 512, 512]" - # t939 = ltorch.matmul(t938, t875) # t939: "cuda:0 bf16[1, 32, 512, 128]" - # t939 = prims.matmul(t938, t875) # t939: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t940 = ltorch.transpose(t939, 1, 2) # t940: "cuda:0 bf16[1, 512, 32, 128]" - # t940 = prims.transpose(t939, (0, 2, 1, 3)) # t940: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t941 = ltorch.reshape(t940, 1, 512, 4096) # t941: "cuda:0 bf16[1, 512, 4096]" - # t941 = prims.reshape(t940, (1, 512, 4096)) # t941: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t945 = ltorch.linear(t941, t_transformer_h_5_attn_proj_weight, None) # t945: "cuda:0 bf16[1, 512, 4096]" - # t945 = prims.linear(t941, t_transformer_h_5_attn_proj_weight, None) # t945: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t949 = ltorch.add(t945, t839, alpha=None) # t949: "cuda:0 bf16[1, 512, 4096]" - # t946 = prims.convert_element_type(t945, dtypes.float32) # t946: "cuda:0 f32[1, 512, 4096]" - # t947 = prims.convert_element_type(t839, dtypes.float32) # t947: "cuda:0 f32[1, 512, 4096]" - # t948 = prims.add(t946, t947) # t948: "cuda:0 f32[1, 512, 4096]" - # t949 = prims.convert_element_type(t948, dtypes.bfloat16) # t949: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t950 = prims.convert_element_type(t949, dtypes.float32) # t950: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t951 = ltorch.mul(t950, t950) # t951: "cuda:0 f32[1, 512, 4096]" - # t951 = prims.mul(t950, t950) # t951: "cuda:0 f32[1, 512, 4096]" - t955 = ltorch.mean(t951, -1, True, dtype=None) # t955: "cuda:0 f32[1, 512, 1]" - # t953 = prims.sum(t951, (2,)) # t953: "cuda:0 f32[1, 512]" - # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: "cuda:0 f32[1, 512, 1]" - # t955 = ltorch.true_divide(t954, 4096) # t955: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t955 = prims.div(t954, 4096.0) # t955: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t957 = ltorch.add(t955, 1e-05, alpha=None) # t957: "cuda:0 f32[1, 512, 1]" - # t957 = prims.add(t955, 1e-05) # t957: "cuda:0 f32[1, 512, 1]" - t958 = ltorch.rsqrt(t957) # t958: "cuda:0 f32[1, 512, 1]" - # t958 = prims.rsqrt(t957) # t958: "cuda:0 f32[1, 512, 1]" - t960 = ltorch.mul(t950, t958) # t960: "cuda:0 f32[1, 512, 4096]" - # t959 = prims.broadcast_in_dim(t958, (1, 512, 4096), (0, 1, 2)) # t959: "cuda:0 f32[1, 512, 4096]" - # t960 = prims.mul(t950, t959) # t960: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t961 = ltorch.to(t960, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t961: "cuda:0 bf16[1, 512, 4096]" - # t961 = prims.convert_element_type(t960, dtypes.bfloat16) # t961: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t971 = ltorch.mul(t961, t_transformer_h_5_norm_2_weight) # t971: "cuda:0 bf16[1, 512, 4096]" - # t967 = prims.broadcast_in_dim(t_transformer_h_5_norm_2_weight, (1, 512, 4096), (2,)) # t967: "cuda:0 bf16[1, 512, 4096]" - # t968 = prims.convert_element_type(t961, dtypes.float32) # t968: "cuda:0 f32[1, 512, 4096]" - # t969 = prims.convert_element_type(t967, dtypes.float32) # t969: "cuda:0 f32[1, 512, 4096]" - # t970 = prims.mul(t968, t969) # t970: "cuda:0 f32[1, 512, 4096]" - # t971 = prims.convert_element_type(t970, dtypes.bfloat16) # t971: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t976 = ltorch.linear(t971, t_transformer_h_5_mlp_fc_1_weight, None) # t976: "cuda:0 bf16[1, 512, 11008]" - # t976 = prims.linear(t971, t_transformer_h_5_mlp_fc_1_weight, None) # t976: "cuda:0 bf16[1, 512, 11008]" - t980 = ltorch.linear(t971, t_transformer_h_5_mlp_fc_2_weight, None) # t980: "cuda:0 bf16[1, 512, 11008]" - # t980 = prims.linear(t971, t_transformer_h_5_mlp_fc_2_weight, None) # t980: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t990 = ltorch.silu(t976, False) # t990: "cuda:0 bf16[1, 512, 11008]" - # t981 = prims.convert_element_type(t976, dtypes.float32) # t981: "cuda:0 f32[1, 512, 11008]" - # t982 = prims.neg(t981) # t982: "cuda:0 f32[1, 512, 11008]" - # t983 = prims.exp(t982) # t983: "cuda:0 f32[1, 512, 11008]" - # t984 = prims.add(1.0, t983) # t984: "cuda:0 f32[1, 512, 11008]" - # t985 = prims.reciprocal(t984) # t985: "cuda:0 f32[1, 512, 11008]" - # t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: "cuda:0 bf16[1, 512, 11008]" - # t987 = prims.convert_element_type(t976, dtypes.float32) # t987: "cuda:0 f32[1, 512, 11008]" - # t988 = prims.convert_element_type(t986, dtypes.float32) # t988: "cuda:0 f32[1, 512, 11008]" - # t989 = prims.mul(t987, t988) # t989: "cuda:0 f32[1, 512, 11008]" - # t990 = prims.convert_element_type(t989, dtypes.bfloat16) # t990: "cuda:0 bf16[1, 512, 11008]" - t994 = ltorch.mul(t990, t980) # t994: "cuda:0 bf16[1, 512, 11008]" - # t991 = prims.convert_element_type(t990, dtypes.float32) # t991: "cuda:0 f32[1, 512, 11008]" - # t992 = prims.convert_element_type(t980, dtypes.float32) # t992: "cuda:0 f32[1, 512, 11008]" - # t993 = prims.mul(t991, t992) # t993: "cuda:0 f32[1, 512, 11008]" - # t994 = prims.convert_element_type(t993, dtypes.bfloat16) # t994: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t998 = ltorch.linear(t994, t_transformer_h_5_mlp_proj_weight, None) # t998: "cuda:0 bf16[1, 512, 4096]" - # t998 = prims.linear(t994, t_transformer_h_5_mlp_proj_weight, None) # t998: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1002 = ltorch.add(t998, t949, alpha=None) # t1002: "cuda:0 bf16[1, 512, 4096]" - # t999 = prims.convert_element_type(t998, dtypes.float32) # t999: "cuda:0 f32[1, 512, 4096]" - # t1000 = prims.convert_element_type(t949, dtypes.float32) # t1000: "cuda:0 f32[1, 512, 4096]" - # t1001 = prims.add(t999, t1000) # t1001: "cuda:0 f32[1, 512, 4096]" - # t1002 = prims.convert_element_type(t1001, dtypes.bfloat16) # t1002: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1004 = prims.convert_element_type(t1002, dtypes.float32) # t1004: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1005 = ltorch.mul(t1004, t1004) # t1005: "cuda:0 f32[1, 512, 4096]" - # t1005 = prims.mul(t1004, t1004) # t1005: "cuda:0 f32[1, 512, 4096]" - t1009 = ltorch.mean(t1005, -1, True, dtype=None) # t1009: "cuda:0 f32[1, 512, 1]" - # t1007 = prims.sum(t1005, (2,)) # t1007: "cuda:0 f32[1, 512]" - # t1008 = prims.broadcast_in_dim(t1007, [1, 512, 1], [0, 1]) # t1008: "cuda:0 f32[1, 512, 1]" - # t1009 = ltorch.true_divide(t1008, 4096) # t1009: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1009 = prims.div(t1008, 4096.0) # t1009: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1011 = ltorch.add(t1009, 1e-05, alpha=None) # t1011: "cuda:0 f32[1, 512, 1]" - # t1011 = prims.add(t1009, 1e-05) # t1011: "cuda:0 f32[1, 512, 1]" - t1012 = ltorch.rsqrt(t1011) # t1012: "cuda:0 f32[1, 512, 1]" - # t1012 = prims.rsqrt(t1011) # t1012: "cuda:0 f32[1, 512, 1]" - t1014 = ltorch.mul(t1004, t1012) # t1014: "cuda:0 f32[1, 512, 4096]" - # t1013 = prims.broadcast_in_dim(t1012, (1, 512, 4096), (0, 1, 2)) # t1013: "cuda:0 f32[1, 512, 4096]" - # t1014 = prims.mul(t1004, t1013) # t1014: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1015 = ltorch.to(t1014, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1015: "cuda:0 bf16[1, 512, 4096]" - # t1015 = prims.convert_element_type(t1014, dtypes.bfloat16) # t1015: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1025 = ltorch.mul(t1015, t_transformer_h_6_norm_1_weight) # t1025: "cuda:0 bf16[1, 512, 4096]" - # t1021 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_weight, (1, 512, 4096), (2,)) # t1021: "cuda:0 bf16[1, 512, 4096]" - # t1022 = prims.convert_element_type(t1015, dtypes.float32) # t1022: "cuda:0 f32[1, 512, 4096]" - # t1023 = prims.convert_element_type(t1021, dtypes.float32) # t1023: "cuda:0 f32[1, 512, 4096]" - # t1024 = prims.mul(t1022, t1023) # t1024: "cuda:0 f32[1, 512, 4096]" - # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1030 = ltorch.linear(t1025, t_transformer_h_6_attn_attn_weight, None) # t1030: "cuda:0 bf16[1, 512, 12288]" - # t1030 = prims.linear(t1025, t_transformer_h_6_attn_attn_weight, None) # t1030: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1031 = ltorch.view(t1030, 1, 512, 32, 3, 128) # t1031: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1031 = ltorch.reshape(t1030, (1, 512, 32, 3, 128)) # t1031: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1031 = prims.reshape(t1030, (1, 512, 32, 3, 128)) # t1031: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1032 = ltorch.permute(t1031, 0, 2, 3, 1, 4) # t1032: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1032 = prims.transpose(t1031, (0, 2, 3, 1, 4)) # t1032: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1033, t1034, t1035) = ltorch.split(t1032, (1, 1, 1), 2) - # t1033 = prims.slice_prim(t1032, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1033: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1034 = prims.slice_prim(t1032, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1034: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1035 = prims.slice_prim(t1032, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1035: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1036 = ltorch.reshape(t1033, 1, -1, 512, 128) # t1036: "cuda:0 bf16[1, 32, 512, 128]" - # t1036 = prims.reshape(t1033, (1, 32, 512, 128)) # t1036: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1037 = ltorch.reshape(t1034, 1, -1, 512, 128) # t1037: "cuda:0 bf16[1, 32, 512, 128]" - # t1037 = prims.reshape(t1034, (1, 32, 512, 128)) # t1037: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1038 = ltorch.reshape(t1035, 1, -1, 512, 128) # t1038: "cuda:0 bf16[1, 32, 512, 128]" - # t1038 = prims.reshape(t1035, (1, 32, 512, 128)) # t1038: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1039 = ltorch.getitem(t1036, (..., slice(None, 128, None))) # t1039: "cuda:0 bf16[1, 32, 512, 128]" - # t1039 = prims.slice_prim(t1036, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1039: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1040 = ltorch.getitem(t1039, (..., slice(None, 64, None))) # t1040: "cuda:0 bf16[1, 32, 512, 64]" - # t1040 = prims.slice_prim(t1039, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1040: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1041 = ltorch.getitem(t1039, (..., slice(64, None, None))) # t1041: "cuda:0 bf16[1, 32, 512, 64]" - # t1041 = prims.slice_prim(t1039, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1041: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1044 = ltorch.neg(t1041) # t1044: "cuda:0 bf16[1, 32, 512, 64]" - # t1042 = prims.convert_element_type(t1041, dtypes.float32) # t1042: "cuda:0 f32[1, 32, 512, 64]" - # t1043 = prims.neg(t1042) # t1043: "cuda:0 f32[1, 32, 512, 64]" - # t1044 = prims.convert_element_type(t1043, dtypes.bfloat16) # t1044: "cuda:0 bf16[1, 32, 512, 64]" - t1045 = ltorch.cat((t1044, t1040), -1) # t1045: "cuda:0 bf16[1, 32, 512, 128]" - # t1045 = prims.cat((t1044, t1040), -1) # t1045: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1048 = ltorch.mul(t1039, cos) # t1048: "cuda:0 f32[1, 32, 512, 128]" - # t1046 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1046: "cuda:0 f32[1, 32, 512, 128]" - # t1047 = prims.convert_element_type(t1039, dtypes.float32) # t1047: "cuda:0 f32[1, 32, 512, 128]" - # t1048 = prims.mul(t1047, t1046) # t1048: "cuda:0 f32[1, 32, 512, 128]" - t1051 = ltorch.mul(t1045, sin) # t1051: "cuda:0 f32[1, 32, 512, 128]" - # t1049 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1049: "cuda:0 f32[1, 32, 512, 128]" - # t1050 = prims.convert_element_type(t1045, dtypes.float32) # t1050: "cuda:0 f32[1, 32, 512, 128]" - # t1051 = prims.mul(t1050, t1049) # t1051: "cuda:0 f32[1, 32, 512, 128]" - t1052 = ltorch.add(t1048, t1051, alpha=None) # t1052: "cuda:0 f32[1, 32, 512, 128]" - # t1052 = prims.add(t1048, t1051) # t1052: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1053 = ltorch.to(t1052, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1053: "cuda:0 bf16[1, 32, 512, 128]" - # t1053 = prims.convert_element_type(t1052, dtypes.bfloat16) # t1053: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1054 = ltorch.getitem(t1037, (..., slice(None, 128, None))) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - # t1054 = prims.slice_prim(t1037, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1055 = ltorch.getitem(t1054, (..., slice(None, 64, None))) # t1055: "cuda:0 bf16[1, 32, 512, 64]" - # t1055 = prims.slice_prim(t1054, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1055: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1056 = ltorch.getitem(t1054, (..., slice(64, None, None))) # t1056: "cuda:0 bf16[1, 32, 512, 64]" - # t1056 = prims.slice_prim(t1054, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1056: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1059 = ltorch.neg(t1056) # t1059: "cuda:0 bf16[1, 32, 512, 64]" - # t1057 = prims.convert_element_type(t1056, dtypes.float32) # t1057: "cuda:0 f32[1, 32, 512, 64]" - # t1058 = prims.neg(t1057) # t1058: "cuda:0 f32[1, 32, 512, 64]" - # t1059 = prims.convert_element_type(t1058, dtypes.bfloat16) # t1059: "cuda:0 bf16[1, 32, 512, 64]" - t1060 = ltorch.cat((t1059, t1055), -1) # t1060: "cuda:0 bf16[1, 32, 512, 128]" - # t1060 = prims.cat((t1059, t1055), -1) # t1060: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1063 = ltorch.mul(t1054, cos) # t1063: "cuda:0 f32[1, 32, 512, 128]" - # t1061 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1061: "cuda:0 f32[1, 32, 512, 128]" - # t1062 = prims.convert_element_type(t1054, dtypes.float32) # t1062: "cuda:0 f32[1, 32, 512, 128]" - # t1063 = prims.mul(t1062, t1061) # t1063: "cuda:0 f32[1, 32, 512, 128]" - t1066 = ltorch.mul(t1060, sin) # t1066: "cuda:0 f32[1, 32, 512, 128]" - # t1064 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1064: "cuda:0 f32[1, 32, 512, 128]" - # t1065 = prims.convert_element_type(t1060, dtypes.float32) # t1065: "cuda:0 f32[1, 32, 512, 128]" - # t1066 = prims.mul(t1065, t1064) # t1066: "cuda:0 f32[1, 32, 512, 128]" - t1067 = ltorch.add(t1063, t1066, alpha=None) # t1067: "cuda:0 f32[1, 32, 512, 128]" - # t1067 = prims.add(t1063, t1066) # t1067: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1068 = ltorch.to(t1067, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1068: "cuda:0 bf16[1, 32, 512, 128]" - # t1068 = prims.convert_element_type(t1067, dtypes.bfloat16) # t1068: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1069 = ltorch.getitem(t1036, (..., slice(128, None, None))) # t1069: "cuda:0 bf16[1, 32, 512, 0]" - # t1069 = prims.slice_prim(t1036, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1069: "cuda:0 bf16[1, 32, 512, 0]" - t1070 = ltorch.cat((t1053, t1069), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - # t1070 = prims.cat((t1053, t1069), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1071 = ltorch.getitem(t1037, (..., slice(128, None, None))) # t1071: "cuda:0 bf16[1, 32, 512, 0]" - # t1071 = prims.slice_prim(t1037, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1071: "cuda:0 bf16[1, 32, 512, 0]" - t1072 = ltorch.cat((t1068, t1071), -1) # t1072: "cuda:0 bf16[1, 32, 512, 128]" - # t1072 = prims.cat((t1068, t1071), -1) # t1072: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1102 = ltorch.scaled_dot_product_attention(t1070, t1072, t1038, None, 0.0, True, scale=0.08838834764831843) # t1102: "cuda:0 bf16[1, 32, 512, 128]" - # t1075 = ltorch.mul(t1070, 0.29730177875068026) # t1075: "cuda:0 bf16[1, 32, 512, 128]" - # t1073 = prims.convert_element_type(t1070, dtypes.float32) # t1073: "cuda:0 f32[1, 32, 512, 128]" - # t1074 = prims.mul(t1073, 0.29730177875068026) # t1074: "cuda:0 f32[1, 32, 512, 128]" - # t1075 = prims.convert_element_type(t1074, dtypes.bfloat16) # t1075: "cuda:0 bf16[1, 32, 512, 128]" - # t1076 = ltorch.transpose(t1072, -2, -1) # t1076: "cuda:0 bf16[1, 32, 128, 512]" - # t1076 = prims.transpose(t1072, (0, 1, 3, 2)) # t1076: "cuda:0 bf16[1, 32, 128, 512]" - # t1079 = ltorch.mul(t1076, 0.29730177875068026) # t1079: "cuda:0 bf16[1, 32, 128, 512]" - # t1077 = prims.convert_element_type(t1076, dtypes.float32) # t1077: "cuda:0 f32[1, 32, 128, 512]" - # t1078 = prims.mul(t1077, 0.29730177875068026) # t1078: "cuda:0 f32[1, 32, 128, 512]" - # t1079 = prims.convert_element_type(t1078, dtypes.bfloat16) # t1079: "cuda:0 bf16[1, 32, 128, 512]" - # t1080 = ltorch.matmul(t1075, t1079) # t1080: "cuda:0 bf16[1, 32, 512, 512]" - # t1080 = prims.matmul(t1075, t1079) # t1080: "cuda:0 bf16[1, 32, 512, 512]" - # t1090 = ltorch.tril(t1080, 0, fill_value=-float('inf')) # t1090: "cuda:0 bf16[1, 32, 512, 512]" - # t1081 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1081: "cuda:0 i64[512]" - # t1081 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1081: "cuda:0 i64[512]" - # t1082 = ltorch.unsqueeze(t1081, -1) # t1082: "cuda:0 i64[512, 1]" - # t1082 = prims.broadcast_in_dim(t1081, [512, 1], [0]) # t1082: "cuda:0 i64[512, 1]" - # t1083 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1083: "cuda:0 i64[512]" - # t1083 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1083: "cuda:0 i64[512]" - # t1084 = ltorch.unsqueeze(t1083, -2) # t1084: "cuda:0 i64[1, 512]" - # t1084 = prims.broadcast_in_dim(t1083, [1, 512], [1]) # t1084: "cuda:0 i64[1, 512]" - # t1085 = ltorch.add(t1082, 0, alpha=None) # t1085: "cuda:0 i64[512, 1]" - # t1085 = prims.add(t1082, 0) # t1085: "cuda:0 i64[512, 1]" - # t1088 = ltorch.ge(t1085, t1084) # t1088: "cuda:0 b8[512, 512]" - # t1086 = prims.broadcast_in_dim(t1085, (512, 512), (0, 1)) # t1086: "cuda:0 i64[512, 512]" - # t1087 = prims.broadcast_in_dim(t1084, (512, 512), (0, 1)) # t1087: "cuda:0 i64[512, 512]" - # t1088 = prims.ge(t1086, t1087) # t1088: "cuda:0 b8[512, 512]" - # t1090 = ltorch.where(t1088, t1080, -float('inf')) # t1090: "cuda:0 bf16[1, 32, 512, 512]" - # t1089 = prims.broadcast_in_dim(t1088, (1, 32, 512, 512), (2, 3)) # t1089: "cuda:0 b8[1, 32, 512, 512]" - # t1090 = prims.where(t1089, t1080, -float('inf')) # t1090: "cuda:0 bf16[1, 32, 512, 512]" - # t1101 = ltorch._softmax(t1090, -1, dtype=None) # t1101: "cuda:0 bf16[1, 32, 512, 512]" - # t1091 = ltorch.to(t1090, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1091: "cuda:0 f32[1, 32, 512, 512]" - # t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: "cuda:0 f32[1, 32, 512, 512]" - # t1093 = ltorch.amax(t1091, -1, True) # t1093: "cuda:0 f32[1, 32, 512, 1]" - # t1092 = prims.amax(t1091, (3,)) # t1092: "cuda:0 f32[1, 32, 512]" - # t1093 = prims.broadcast_in_dim(t1092, [1, 32, 512, 1], [0, 1, 2]) # t1093: "cuda:0 f32[1, 32, 512, 1]" - # t1095 = ltorch.sub(t1091, t1093, alpha=None) # t1095: "cuda:0 f32[1, 32, 512, 512]" - # t1094 = prims.broadcast_in_dim(t1093, (1, 32, 512, 512), (0, 1, 2, 3)) # t1094: "cuda:0 f32[1, 32, 512, 512]" - # t1095 = prims.sub(t1091, t1094) # t1095: "cuda:0 f32[1, 32, 512, 512]" - # t1096 = ltorch.exp(t1095) # t1096: "cuda:0 f32[1, 32, 512, 512]" - # t1096 = prims.exp(t1095) # t1096: "cuda:0 f32[1, 32, 512, 512]" - # t1098 = ltorch.sum(t1096, -1, True, dtype=None) # t1098: "cuda:0 f32[1, 32, 512, 1]" - # t1097 = prims.sum(t1096, (3,)) # t1097: "cuda:0 f32[1, 32, 512]" - # t1098 = prims.broadcast_in_dim(t1097, [1, 32, 512, 1], [0, 1, 2]) # t1098: "cuda:0 f32[1, 32, 512, 1]" - # t1100 = ltorch.true_divide(t1096, t1098) # t1100: "cuda:0 f32[1, 32, 512, 512]" - # t1099 = prims.broadcast_in_dim(t1098, (1, 32, 512, 512), (0, 1, 2, 3)) # t1099: "cuda:0 f32[1, 32, 512, 512]" - # t1100 = prims.div(t1096, t1099) # t1100: "cuda:0 f32[1, 32, 512, 512]" - # t1101 = ltorch.to(t1100, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1101: "cuda:0 bf16[1, 32, 512, 512]" - # t1101 = prims.convert_element_type(t1100, dtypes.bfloat16) # t1101: "cuda:0 bf16[1, 32, 512, 512]" - # t1102 = ltorch.matmul(t1101, t1038) # t1102: "cuda:0 bf16[1, 32, 512, 128]" - # t1102 = prims.matmul(t1101, t1038) # t1102: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1103 = ltorch.transpose(t1102, 1, 2) # t1103: "cuda:0 bf16[1, 512, 32, 128]" - # t1103 = prims.transpose(t1102, (0, 2, 1, 3)) # t1103: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1104 = ltorch.reshape(t1103, 1, 512, 4096) # t1104: "cuda:0 bf16[1, 512, 4096]" - # t1104 = prims.reshape(t1103, (1, 512, 4096)) # t1104: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1108 = ltorch.linear(t1104, t_transformer_h_6_attn_proj_weight, None) # t1108: "cuda:0 bf16[1, 512, 4096]" - # t1108 = prims.linear(t1104, t_transformer_h_6_attn_proj_weight, None) # t1108: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1112 = ltorch.add(t1108, t1002, alpha=None) # t1112: "cuda:0 bf16[1, 512, 4096]" - # t1109 = prims.convert_element_type(t1108, dtypes.float32) # t1109: "cuda:0 f32[1, 512, 4096]" - # t1110 = prims.convert_element_type(t1002, dtypes.float32) # t1110: "cuda:0 f32[1, 512, 4096]" - # t1111 = prims.add(t1109, t1110) # t1111: "cuda:0 f32[1, 512, 4096]" - # t1112 = prims.convert_element_type(t1111, dtypes.bfloat16) # t1112: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1113 = prims.convert_element_type(t1112, dtypes.float32) # t1113: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1114 = ltorch.mul(t1113, t1113) # t1114: "cuda:0 f32[1, 512, 4096]" - # t1114 = prims.mul(t1113, t1113) # t1114: "cuda:0 f32[1, 512, 4096]" - t1118 = ltorch.mean(t1114, -1, True, dtype=None) # t1118: "cuda:0 f32[1, 512, 1]" - # t1116 = prims.sum(t1114, (2,)) # t1116: "cuda:0 f32[1, 512]" - # t1117 = prims.broadcast_in_dim(t1116, [1, 512, 1], [0, 1]) # t1117: "cuda:0 f32[1, 512, 1]" - # t1118 = ltorch.true_divide(t1117, 4096) # t1118: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1118 = prims.div(t1117, 4096.0) # t1118: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1120 = ltorch.add(t1118, 1e-05, alpha=None) # t1120: "cuda:0 f32[1, 512, 1]" - # t1120 = prims.add(t1118, 1e-05) # t1120: "cuda:0 f32[1, 512, 1]" - t1121 = ltorch.rsqrt(t1120) # t1121: "cuda:0 f32[1, 512, 1]" - # t1121 = prims.rsqrt(t1120) # t1121: "cuda:0 f32[1, 512, 1]" - t1123 = ltorch.mul(t1113, t1121) # t1123: "cuda:0 f32[1, 512, 4096]" - # t1122 = prims.broadcast_in_dim(t1121, (1, 512, 4096), (0, 1, 2)) # t1122: "cuda:0 f32[1, 512, 4096]" - # t1123 = prims.mul(t1113, t1122) # t1123: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1124 = ltorch.to(t1123, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1124: "cuda:0 bf16[1, 512, 4096]" - # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1134 = ltorch.mul(t1124, t_transformer_h_6_norm_2_weight) # t1134: "cuda:0 bf16[1, 512, 4096]" - # t1130 = prims.broadcast_in_dim(t_transformer_h_6_norm_2_weight, (1, 512, 4096), (2,)) # t1130: "cuda:0 bf16[1, 512, 4096]" - # t1131 = prims.convert_element_type(t1124, dtypes.float32) # t1131: "cuda:0 f32[1, 512, 4096]" - # t1132 = prims.convert_element_type(t1130, dtypes.float32) # t1132: "cuda:0 f32[1, 512, 4096]" - # t1133 = prims.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 4096]" - # t1134 = prims.convert_element_type(t1133, dtypes.bfloat16) # t1134: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1139 = ltorch.linear(t1134, t_transformer_h_6_mlp_fc_1_weight, None) # t1139: "cuda:0 bf16[1, 512, 11008]" - # t1139 = prims.linear(t1134, t_transformer_h_6_mlp_fc_1_weight, None) # t1139: "cuda:0 bf16[1, 512, 11008]" - t1143 = ltorch.linear(t1134, t_transformer_h_6_mlp_fc_2_weight, None) # t1143: "cuda:0 bf16[1, 512, 11008]" - # t1143 = prims.linear(t1134, t_transformer_h_6_mlp_fc_2_weight, None) # t1143: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1153 = ltorch.silu(t1139, False) # t1153: "cuda:0 bf16[1, 512, 11008]" - # t1144 = prims.convert_element_type(t1139, dtypes.float32) # t1144: "cuda:0 f32[1, 512, 11008]" - # t1145 = prims.neg(t1144) # t1145: "cuda:0 f32[1, 512, 11008]" - # t1146 = prims.exp(t1145) # t1146: "cuda:0 f32[1, 512, 11008]" - # t1147 = prims.add(1.0, t1146) # t1147: "cuda:0 f32[1, 512, 11008]" - # t1148 = prims.reciprocal(t1147) # t1148: "cuda:0 f32[1, 512, 11008]" - # t1149 = prims.convert_element_type(t1148, dtypes.bfloat16) # t1149: "cuda:0 bf16[1, 512, 11008]" - # t1150 = prims.convert_element_type(t1139, dtypes.float32) # t1150: "cuda:0 f32[1, 512, 11008]" - # t1151 = prims.convert_element_type(t1149, dtypes.float32) # t1151: "cuda:0 f32[1, 512, 11008]" - # t1152 = prims.mul(t1150, t1151) # t1152: "cuda:0 f32[1, 512, 11008]" - # t1153 = prims.convert_element_type(t1152, dtypes.bfloat16) # t1153: "cuda:0 bf16[1, 512, 11008]" - t1157 = ltorch.mul(t1153, t1143) # t1157: "cuda:0 bf16[1, 512, 11008]" - # t1154 = prims.convert_element_type(t1153, dtypes.float32) # t1154: "cuda:0 f32[1, 512, 11008]" - # t1155 = prims.convert_element_type(t1143, dtypes.float32) # t1155: "cuda:0 f32[1, 512, 11008]" - # t1156 = prims.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 11008]" - # t1157 = prims.convert_element_type(t1156, dtypes.bfloat16) # t1157: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1161 = ltorch.linear(t1157, t_transformer_h_6_mlp_proj_weight, None) # t1161: "cuda:0 bf16[1, 512, 4096]" - # t1161 = prims.linear(t1157, t_transformer_h_6_mlp_proj_weight, None) # t1161: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1165 = ltorch.add(t1161, t1112, alpha=None) # t1165: "cuda:0 bf16[1, 512, 4096]" - # t1162 = prims.convert_element_type(t1161, dtypes.float32) # t1162: "cuda:0 f32[1, 512, 4096]" - # t1163 = prims.convert_element_type(t1112, dtypes.float32) # t1163: "cuda:0 f32[1, 512, 4096]" - # t1164 = prims.add(t1162, t1163) # t1164: "cuda:0 f32[1, 512, 4096]" - # t1165 = prims.convert_element_type(t1164, dtypes.bfloat16) # t1165: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1167 = prims.convert_element_type(t1165, dtypes.float32) # t1167: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1168 = ltorch.mul(t1167, t1167) # t1168: "cuda:0 f32[1, 512, 4096]" - # t1168 = prims.mul(t1167, t1167) # t1168: "cuda:0 f32[1, 512, 4096]" - t1172 = ltorch.mean(t1168, -1, True, dtype=None) # t1172: "cuda:0 f32[1, 512, 1]" - # t1170 = prims.sum(t1168, (2,)) # t1170: "cuda:0 f32[1, 512]" - # t1171 = prims.broadcast_in_dim(t1170, [1, 512, 1], [0, 1]) # t1171: "cuda:0 f32[1, 512, 1]" - # t1172 = ltorch.true_divide(t1171, 4096) # t1172: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1172 = prims.div(t1171, 4096.0) # t1172: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1174 = ltorch.add(t1172, 1e-05, alpha=None) # t1174: "cuda:0 f32[1, 512, 1]" - # t1174 = prims.add(t1172, 1e-05) # t1174: "cuda:0 f32[1, 512, 1]" - t1175 = ltorch.rsqrt(t1174) # t1175: "cuda:0 f32[1, 512, 1]" - # t1175 = prims.rsqrt(t1174) # t1175: "cuda:0 f32[1, 512, 1]" - t1177 = ltorch.mul(t1167, t1175) # t1177: "cuda:0 f32[1, 512, 4096]" - # t1176 = prims.broadcast_in_dim(t1175, (1, 512, 4096), (0, 1, 2)) # t1176: "cuda:0 f32[1, 512, 4096]" - # t1177 = prims.mul(t1167, t1176) # t1177: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1178 = ltorch.to(t1177, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1178: "cuda:0 bf16[1, 512, 4096]" - # t1178 = prims.convert_element_type(t1177, dtypes.bfloat16) # t1178: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1188 = ltorch.mul(t1178, t_transformer_h_7_norm_1_weight) # t1188: "cuda:0 bf16[1, 512, 4096]" - # t1184 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_weight, (1, 512, 4096), (2,)) # t1184: "cuda:0 bf16[1, 512, 4096]" - # t1185 = prims.convert_element_type(t1178, dtypes.float32) # t1185: "cuda:0 f32[1, 512, 4096]" - # t1186 = prims.convert_element_type(t1184, dtypes.float32) # t1186: "cuda:0 f32[1, 512, 4096]" - # t1187 = prims.mul(t1185, t1186) # t1187: "cuda:0 f32[1, 512, 4096]" - # t1188 = prims.convert_element_type(t1187, dtypes.bfloat16) # t1188: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1193 = ltorch.linear(t1188, t_transformer_h_7_attn_attn_weight, None) # t1193: "cuda:0 bf16[1, 512, 12288]" - # t1193 = prims.linear(t1188, t_transformer_h_7_attn_attn_weight, None) # t1193: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1194 = ltorch.view(t1193, 1, 512, 32, 3, 128) # t1194: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1194 = ltorch.reshape(t1193, (1, 512, 32, 3, 128)) # t1194: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1194 = prims.reshape(t1193, (1, 512, 32, 3, 128)) # t1194: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1195 = ltorch.permute(t1194, 0, 2, 3, 1, 4) # t1195: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1195 = prims.transpose(t1194, (0, 2, 3, 1, 4)) # t1195: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1196, t1197, t1198) = ltorch.split(t1195, (1, 1, 1), 2) - # t1196 = prims.slice_prim(t1195, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1196: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1197 = prims.slice_prim(t1195, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1197: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1198 = prims.slice_prim(t1195, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1198: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1199 = ltorch.reshape(t1196, 1, -1, 512, 128) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - # t1199 = prims.reshape(t1196, (1, 32, 512, 128)) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1200 = ltorch.reshape(t1197, 1, -1, 512, 128) # t1200: "cuda:0 bf16[1, 32, 512, 128]" - # t1200 = prims.reshape(t1197, (1, 32, 512, 128)) # t1200: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1201 = ltorch.reshape(t1198, 1, -1, 512, 128) # t1201: "cuda:0 bf16[1, 32, 512, 128]" - # t1201 = prims.reshape(t1198, (1, 32, 512, 128)) # t1201: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1202 = ltorch.getitem(t1199, (..., slice(None, 128, None))) # t1202: "cuda:0 bf16[1, 32, 512, 128]" - # t1202 = prims.slice_prim(t1199, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1202: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1203 = ltorch.getitem(t1202, (..., slice(None, 64, None))) # t1203: "cuda:0 bf16[1, 32, 512, 64]" - # t1203 = prims.slice_prim(t1202, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1203: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1204 = ltorch.getitem(t1202, (..., slice(64, None, None))) # t1204: "cuda:0 bf16[1, 32, 512, 64]" - # t1204 = prims.slice_prim(t1202, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1204: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1207 = ltorch.neg(t1204) # t1207: "cuda:0 bf16[1, 32, 512, 64]" - # t1205 = prims.convert_element_type(t1204, dtypes.float32) # t1205: "cuda:0 f32[1, 32, 512, 64]" - # t1206 = prims.neg(t1205) # t1206: "cuda:0 f32[1, 32, 512, 64]" - # t1207 = prims.convert_element_type(t1206, dtypes.bfloat16) # t1207: "cuda:0 bf16[1, 32, 512, 64]" - t1208 = ltorch.cat((t1207, t1203), -1) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - # t1208 = prims.cat((t1207, t1203), -1) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1211 = ltorch.mul(t1202, cos) # t1211: "cuda:0 f32[1, 32, 512, 128]" - # t1209 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1209: "cuda:0 f32[1, 32, 512, 128]" - # t1210 = prims.convert_element_type(t1202, dtypes.float32) # t1210: "cuda:0 f32[1, 32, 512, 128]" - # t1211 = prims.mul(t1210, t1209) # t1211: "cuda:0 f32[1, 32, 512, 128]" - t1214 = ltorch.mul(t1208, sin) # t1214: "cuda:0 f32[1, 32, 512, 128]" - # t1212 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1212: "cuda:0 f32[1, 32, 512, 128]" - # t1213 = prims.convert_element_type(t1208, dtypes.float32) # t1213: "cuda:0 f32[1, 32, 512, 128]" - # t1214 = prims.mul(t1213, t1212) # t1214: "cuda:0 f32[1, 32, 512, 128]" - t1215 = ltorch.add(t1211, t1214, alpha=None) # t1215: "cuda:0 f32[1, 32, 512, 128]" - # t1215 = prims.add(t1211, t1214) # t1215: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1216 = ltorch.to(t1215, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1216: "cuda:0 bf16[1, 32, 512, 128]" - # t1216 = prims.convert_element_type(t1215, dtypes.bfloat16) # t1216: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1217 = ltorch.getitem(t1200, (..., slice(None, 128, None))) # t1217: "cuda:0 bf16[1, 32, 512, 128]" - # t1217 = prims.slice_prim(t1200, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1217: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1218 = ltorch.getitem(t1217, (..., slice(None, 64, None))) # t1218: "cuda:0 bf16[1, 32, 512, 64]" - # t1218 = prims.slice_prim(t1217, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1218: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1219 = ltorch.getitem(t1217, (..., slice(64, None, None))) # t1219: "cuda:0 bf16[1, 32, 512, 64]" - # t1219 = prims.slice_prim(t1217, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1219: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1222 = ltorch.neg(t1219) # t1222: "cuda:0 bf16[1, 32, 512, 64]" - # t1220 = prims.convert_element_type(t1219, dtypes.float32) # t1220: "cuda:0 f32[1, 32, 512, 64]" - # t1221 = prims.neg(t1220) # t1221: "cuda:0 f32[1, 32, 512, 64]" - # t1222 = prims.convert_element_type(t1221, dtypes.bfloat16) # t1222: "cuda:0 bf16[1, 32, 512, 64]" - t1223 = ltorch.cat((t1222, t1218), -1) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - # t1223 = prims.cat((t1222, t1218), -1) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1226 = ltorch.mul(t1217, cos) # t1226: "cuda:0 f32[1, 32, 512, 128]" - # t1224 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1224: "cuda:0 f32[1, 32, 512, 128]" - # t1225 = prims.convert_element_type(t1217, dtypes.float32) # t1225: "cuda:0 f32[1, 32, 512, 128]" - # t1226 = prims.mul(t1225, t1224) # t1226: "cuda:0 f32[1, 32, 512, 128]" - t1229 = ltorch.mul(t1223, sin) # t1229: "cuda:0 f32[1, 32, 512, 128]" - # t1227 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1227: "cuda:0 f32[1, 32, 512, 128]" - # t1228 = prims.convert_element_type(t1223, dtypes.float32) # t1228: "cuda:0 f32[1, 32, 512, 128]" - # t1229 = prims.mul(t1228, t1227) # t1229: "cuda:0 f32[1, 32, 512, 128]" - t1230 = ltorch.add(t1226, t1229, alpha=None) # t1230: "cuda:0 f32[1, 32, 512, 128]" - # t1230 = prims.add(t1226, t1229) # t1230: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1231 = ltorch.to(t1230, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1231: "cuda:0 bf16[1, 32, 512, 128]" - # t1231 = prims.convert_element_type(t1230, dtypes.bfloat16) # t1231: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1232 = ltorch.getitem(t1199, (..., slice(128, None, None))) # t1232: "cuda:0 bf16[1, 32, 512, 0]" - # t1232 = prims.slice_prim(t1199, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1232: "cuda:0 bf16[1, 32, 512, 0]" - t1233 = ltorch.cat((t1216, t1232), -1) # t1233: "cuda:0 bf16[1, 32, 512, 128]" - # t1233 = prims.cat((t1216, t1232), -1) # t1233: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1234 = ltorch.getitem(t1200, (..., slice(128, None, None))) # t1234: "cuda:0 bf16[1, 32, 512, 0]" - # t1234 = prims.slice_prim(t1200, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1234: "cuda:0 bf16[1, 32, 512, 0]" - t1235 = ltorch.cat((t1231, t1234), -1) # t1235: "cuda:0 bf16[1, 32, 512, 128]" - # t1235 = prims.cat((t1231, t1234), -1) # t1235: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1265 = ltorch.scaled_dot_product_attention(t1233, t1235, t1201, None, 0.0, True, scale=0.08838834764831843) # t1265: "cuda:0 bf16[1, 32, 512, 128]" - # t1238 = ltorch.mul(t1233, 0.29730177875068026) # t1238: "cuda:0 bf16[1, 32, 512, 128]" - # t1236 = prims.convert_element_type(t1233, dtypes.float32) # t1236: "cuda:0 f32[1, 32, 512, 128]" - # t1237 = prims.mul(t1236, 0.29730177875068026) # t1237: "cuda:0 f32[1, 32, 512, 128]" - # t1238 = prims.convert_element_type(t1237, dtypes.bfloat16) # t1238: "cuda:0 bf16[1, 32, 512, 128]" - # t1239 = ltorch.transpose(t1235, -2, -1) # t1239: "cuda:0 bf16[1, 32, 128, 512]" - # t1239 = prims.transpose(t1235, (0, 1, 3, 2)) # t1239: "cuda:0 bf16[1, 32, 128, 512]" - # t1242 = ltorch.mul(t1239, 0.29730177875068026) # t1242: "cuda:0 bf16[1, 32, 128, 512]" - # t1240 = prims.convert_element_type(t1239, dtypes.float32) # t1240: "cuda:0 f32[1, 32, 128, 512]" - # t1241 = prims.mul(t1240, 0.29730177875068026) # t1241: "cuda:0 f32[1, 32, 128, 512]" - # t1242 = prims.convert_element_type(t1241, dtypes.bfloat16) # t1242: "cuda:0 bf16[1, 32, 128, 512]" - # t1243 = ltorch.matmul(t1238, t1242) # t1243: "cuda:0 bf16[1, 32, 512, 512]" - # t1243 = prims.matmul(t1238, t1242) # t1243: "cuda:0 bf16[1, 32, 512, 512]" - # t1253 = ltorch.tril(t1243, 0, fill_value=-float('inf')) # t1253: "cuda:0 bf16[1, 32, 512, 512]" - # t1244 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1244: "cuda:0 i64[512]" - # t1244 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1244: "cuda:0 i64[512]" - # t1245 = ltorch.unsqueeze(t1244, -1) # t1245: "cuda:0 i64[512, 1]" - # t1245 = prims.broadcast_in_dim(t1244, [512, 1], [0]) # t1245: "cuda:0 i64[512, 1]" - # t1246 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1246: "cuda:0 i64[512]" - # t1246 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1246: "cuda:0 i64[512]" - # t1247 = ltorch.unsqueeze(t1246, -2) # t1247: "cuda:0 i64[1, 512]" - # t1247 = prims.broadcast_in_dim(t1246, [1, 512], [1]) # t1247: "cuda:0 i64[1, 512]" - # t1248 = ltorch.add(t1245, 0, alpha=None) # t1248: "cuda:0 i64[512, 1]" - # t1248 = prims.add(t1245, 0) # t1248: "cuda:0 i64[512, 1]" - # t1251 = ltorch.ge(t1248, t1247) # t1251: "cuda:0 b8[512, 512]" - # t1249 = prims.broadcast_in_dim(t1248, (512, 512), (0, 1)) # t1249: "cuda:0 i64[512, 512]" - # t1250 = prims.broadcast_in_dim(t1247, (512, 512), (0, 1)) # t1250: "cuda:0 i64[512, 512]" - # t1251 = prims.ge(t1249, t1250) # t1251: "cuda:0 b8[512, 512]" - # t1253 = ltorch.where(t1251, t1243, -float('inf')) # t1253: "cuda:0 bf16[1, 32, 512, 512]" - # t1252 = prims.broadcast_in_dim(t1251, (1, 32, 512, 512), (2, 3)) # t1252: "cuda:0 b8[1, 32, 512, 512]" - # t1253 = prims.where(t1252, t1243, -float('inf')) # t1253: "cuda:0 bf16[1, 32, 512, 512]" - # t1264 = ltorch._softmax(t1253, -1, dtype=None) # t1264: "cuda:0 bf16[1, 32, 512, 512]" - # t1254 = ltorch.to(t1253, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1254: "cuda:0 f32[1, 32, 512, 512]" - # t1254 = prims.convert_element_type(t1253, dtypes.float32) # t1254: "cuda:0 f32[1, 32, 512, 512]" - # t1256 = ltorch.amax(t1254, -1, True) # t1256: "cuda:0 f32[1, 32, 512, 1]" - # t1255 = prims.amax(t1254, (3,)) # t1255: "cuda:0 f32[1, 32, 512]" - # t1256 = prims.broadcast_in_dim(t1255, [1, 32, 512, 1], [0, 1, 2]) # t1256: "cuda:0 f32[1, 32, 512, 1]" - # t1258 = ltorch.sub(t1254, t1256, alpha=None) # t1258: "cuda:0 f32[1, 32, 512, 512]" - # t1257 = prims.broadcast_in_dim(t1256, (1, 32, 512, 512), (0, 1, 2, 3)) # t1257: "cuda:0 f32[1, 32, 512, 512]" - # t1258 = prims.sub(t1254, t1257) # t1258: "cuda:0 f32[1, 32, 512, 512]" - # t1259 = ltorch.exp(t1258) # t1259: "cuda:0 f32[1, 32, 512, 512]" - # t1259 = prims.exp(t1258) # t1259: "cuda:0 f32[1, 32, 512, 512]" - # t1261 = ltorch.sum(t1259, -1, True, dtype=None) # t1261: "cuda:0 f32[1, 32, 512, 1]" - # t1260 = prims.sum(t1259, (3,)) # t1260: "cuda:0 f32[1, 32, 512]" - # t1261 = prims.broadcast_in_dim(t1260, [1, 32, 512, 1], [0, 1, 2]) # t1261: "cuda:0 f32[1, 32, 512, 1]" - # t1263 = ltorch.true_divide(t1259, t1261) # t1263: "cuda:0 f32[1, 32, 512, 512]" - # t1262 = prims.broadcast_in_dim(t1261, (1, 32, 512, 512), (0, 1, 2, 3)) # t1262: "cuda:0 f32[1, 32, 512, 512]" - # t1263 = prims.div(t1259, t1262) # t1263: "cuda:0 f32[1, 32, 512, 512]" - # t1264 = ltorch.to(t1263, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1264: "cuda:0 bf16[1, 32, 512, 512]" - # t1264 = prims.convert_element_type(t1263, dtypes.bfloat16) # t1264: "cuda:0 bf16[1, 32, 512, 512]" - # t1265 = ltorch.matmul(t1264, t1201) # t1265: "cuda:0 bf16[1, 32, 512, 128]" - # t1265 = prims.matmul(t1264, t1201) # t1265: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1266 = ltorch.transpose(t1265, 1, 2) # t1266: "cuda:0 bf16[1, 512, 32, 128]" - # t1266 = prims.transpose(t1265, (0, 2, 1, 3)) # t1266: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1267 = ltorch.reshape(t1266, 1, 512, 4096) # t1267: "cuda:0 bf16[1, 512, 4096]" - # t1267 = prims.reshape(t1266, (1, 512, 4096)) # t1267: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1271 = ltorch.linear(t1267, t_transformer_h_7_attn_proj_weight, None) # t1271: "cuda:0 bf16[1, 512, 4096]" - # t1271 = prims.linear(t1267, t_transformer_h_7_attn_proj_weight, None) # t1271: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1275 = ltorch.add(t1271, t1165, alpha=None) # t1275: "cuda:0 bf16[1, 512, 4096]" - # t1272 = prims.convert_element_type(t1271, dtypes.float32) # t1272: "cuda:0 f32[1, 512, 4096]" - # t1273 = prims.convert_element_type(t1165, dtypes.float32) # t1273: "cuda:0 f32[1, 512, 4096]" - # t1274 = prims.add(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 4096]" - # t1275 = prims.convert_element_type(t1274, dtypes.bfloat16) # t1275: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1276 = prims.convert_element_type(t1275, dtypes.float32) # t1276: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1277 = ltorch.mul(t1276, t1276) # t1277: "cuda:0 f32[1, 512, 4096]" - # t1277 = prims.mul(t1276, t1276) # t1277: "cuda:0 f32[1, 512, 4096]" - t1281 = ltorch.mean(t1277, -1, True, dtype=None) # t1281: "cuda:0 f32[1, 512, 1]" - # t1279 = prims.sum(t1277, (2,)) # t1279: "cuda:0 f32[1, 512]" - # t1280 = prims.broadcast_in_dim(t1279, [1, 512, 1], [0, 1]) # t1280: "cuda:0 f32[1, 512, 1]" - # t1281 = ltorch.true_divide(t1280, 4096) # t1281: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1281 = prims.div(t1280, 4096.0) # t1281: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1283 = ltorch.add(t1281, 1e-05, alpha=None) # t1283: "cuda:0 f32[1, 512, 1]" - # t1283 = prims.add(t1281, 1e-05) # t1283: "cuda:0 f32[1, 512, 1]" - t1284 = ltorch.rsqrt(t1283) # t1284: "cuda:0 f32[1, 512, 1]" - # t1284 = prims.rsqrt(t1283) # t1284: "cuda:0 f32[1, 512, 1]" - t1286 = ltorch.mul(t1276, t1284) # t1286: "cuda:0 f32[1, 512, 4096]" - # t1285 = prims.broadcast_in_dim(t1284, (1, 512, 4096), (0, 1, 2)) # t1285: "cuda:0 f32[1, 512, 4096]" - # t1286 = prims.mul(t1276, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1287 = ltorch.to(t1286, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1287: "cuda:0 bf16[1, 512, 4096]" - # t1287 = prims.convert_element_type(t1286, dtypes.bfloat16) # t1287: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1297 = ltorch.mul(t1287, t_transformer_h_7_norm_2_weight) # t1297: "cuda:0 bf16[1, 512, 4096]" - # t1293 = prims.broadcast_in_dim(t_transformer_h_7_norm_2_weight, (1, 512, 4096), (2,)) # t1293: "cuda:0 bf16[1, 512, 4096]" - # t1294 = prims.convert_element_type(t1287, dtypes.float32) # t1294: "cuda:0 f32[1, 512, 4096]" - # t1295 = prims.convert_element_type(t1293, dtypes.float32) # t1295: "cuda:0 f32[1, 512, 4096]" - # t1296 = prims.mul(t1294, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - # t1297 = prims.convert_element_type(t1296, dtypes.bfloat16) # t1297: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1302 = ltorch.linear(t1297, t_transformer_h_7_mlp_fc_1_weight, None) # t1302: "cuda:0 bf16[1, 512, 11008]" - # t1302 = prims.linear(t1297, t_transformer_h_7_mlp_fc_1_weight, None) # t1302: "cuda:0 bf16[1, 512, 11008]" - t1306 = ltorch.linear(t1297, t_transformer_h_7_mlp_fc_2_weight, None) # t1306: "cuda:0 bf16[1, 512, 11008]" - # t1306 = prims.linear(t1297, t_transformer_h_7_mlp_fc_2_weight, None) # t1306: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1316 = ltorch.silu(t1302, False) # t1316: "cuda:0 bf16[1, 512, 11008]" - # t1307 = prims.convert_element_type(t1302, dtypes.float32) # t1307: "cuda:0 f32[1, 512, 11008]" - # t1308 = prims.neg(t1307) # t1308: "cuda:0 f32[1, 512, 11008]" - # t1309 = prims.exp(t1308) # t1309: "cuda:0 f32[1, 512, 11008]" - # t1310 = prims.add(1.0, t1309) # t1310: "cuda:0 f32[1, 512, 11008]" - # t1311 = prims.reciprocal(t1310) # t1311: "cuda:0 f32[1, 512, 11008]" - # t1312 = prims.convert_element_type(t1311, dtypes.bfloat16) # t1312: "cuda:0 bf16[1, 512, 11008]" - # t1313 = prims.convert_element_type(t1302, dtypes.float32) # t1313: "cuda:0 f32[1, 512, 11008]" - # t1314 = prims.convert_element_type(t1312, dtypes.float32) # t1314: "cuda:0 f32[1, 512, 11008]" - # t1315 = prims.mul(t1313, t1314) # t1315: "cuda:0 f32[1, 512, 11008]" - # t1316 = prims.convert_element_type(t1315, dtypes.bfloat16) # t1316: "cuda:0 bf16[1, 512, 11008]" - t1320 = ltorch.mul(t1316, t1306) # t1320: "cuda:0 bf16[1, 512, 11008]" - # t1317 = prims.convert_element_type(t1316, dtypes.float32) # t1317: "cuda:0 f32[1, 512, 11008]" - # t1318 = prims.convert_element_type(t1306, dtypes.float32) # t1318: "cuda:0 f32[1, 512, 11008]" - # t1319 = prims.mul(t1317, t1318) # t1319: "cuda:0 f32[1, 512, 11008]" - # t1320 = prims.convert_element_type(t1319, dtypes.bfloat16) # t1320: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1324 = ltorch.linear(t1320, t_transformer_h_7_mlp_proj_weight, None) # t1324: "cuda:0 bf16[1, 512, 4096]" - # t1324 = prims.linear(t1320, t_transformer_h_7_mlp_proj_weight, None) # t1324: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1328 = ltorch.add(t1324, t1275, alpha=None) # t1328: "cuda:0 bf16[1, 512, 4096]" - # t1325 = prims.convert_element_type(t1324, dtypes.float32) # t1325: "cuda:0 f32[1, 512, 4096]" - # t1326 = prims.convert_element_type(t1275, dtypes.float32) # t1326: "cuda:0 f32[1, 512, 4096]" - # t1327 = prims.add(t1325, t1326) # t1327: "cuda:0 f32[1, 512, 4096]" - # t1328 = prims.convert_element_type(t1327, dtypes.bfloat16) # t1328: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1330 = prims.convert_element_type(t1328, dtypes.float32) # t1330: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1331 = ltorch.mul(t1330, t1330) # t1331: "cuda:0 f32[1, 512, 4096]" - # t1331 = prims.mul(t1330, t1330) # t1331: "cuda:0 f32[1, 512, 4096]" - t1335 = ltorch.mean(t1331, -1, True, dtype=None) # t1335: "cuda:0 f32[1, 512, 1]" - # t1333 = prims.sum(t1331, (2,)) # t1333: "cuda:0 f32[1, 512]" - # t1334 = prims.broadcast_in_dim(t1333, [1, 512, 1], [0, 1]) # t1334: "cuda:0 f32[1, 512, 1]" - # t1335 = ltorch.true_divide(t1334, 4096) # t1335: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1335 = prims.div(t1334, 4096.0) # t1335: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1337 = ltorch.add(t1335, 1e-05, alpha=None) # t1337: "cuda:0 f32[1, 512, 1]" - # t1337 = prims.add(t1335, 1e-05) # t1337: "cuda:0 f32[1, 512, 1]" - t1338 = ltorch.rsqrt(t1337) # t1338: "cuda:0 f32[1, 512, 1]" - # t1338 = prims.rsqrt(t1337) # t1338: "cuda:0 f32[1, 512, 1]" - t1340 = ltorch.mul(t1330, t1338) # t1340: "cuda:0 f32[1, 512, 4096]" - # t1339 = prims.broadcast_in_dim(t1338, (1, 512, 4096), (0, 1, 2)) # t1339: "cuda:0 f32[1, 512, 4096]" - # t1340 = prims.mul(t1330, t1339) # t1340: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1341 = ltorch.to(t1340, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1341: "cuda:0 bf16[1, 512, 4096]" - # t1341 = prims.convert_element_type(t1340, dtypes.bfloat16) # t1341: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1351 = ltorch.mul(t1341, t_transformer_h_8_norm_1_weight) # t1351: "cuda:0 bf16[1, 512, 4096]" - # t1347 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_weight, (1, 512, 4096), (2,)) # t1347: "cuda:0 bf16[1, 512, 4096]" - # t1348 = prims.convert_element_type(t1341, dtypes.float32) # t1348: "cuda:0 f32[1, 512, 4096]" - # t1349 = prims.convert_element_type(t1347, dtypes.float32) # t1349: "cuda:0 f32[1, 512, 4096]" - # t1350 = prims.mul(t1348, t1349) # t1350: "cuda:0 f32[1, 512, 4096]" - # t1351 = prims.convert_element_type(t1350, dtypes.bfloat16) # t1351: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1356 = ltorch.linear(t1351, t_transformer_h_8_attn_attn_weight, None) # t1356: "cuda:0 bf16[1, 512, 12288]" - # t1356 = prims.linear(t1351, t_transformer_h_8_attn_attn_weight, None) # t1356: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1357 = ltorch.view(t1356, 1, 512, 32, 3, 128) # t1357: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1357 = ltorch.reshape(t1356, (1, 512, 32, 3, 128)) # t1357: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1357 = prims.reshape(t1356, (1, 512, 32, 3, 128)) # t1357: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1358 = ltorch.permute(t1357, 0, 2, 3, 1, 4) # t1358: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1358 = prims.transpose(t1357, (0, 2, 3, 1, 4)) # t1358: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1359, t1360, t1361) = ltorch.split(t1358, (1, 1, 1), 2) - # t1359 = prims.slice_prim(t1358, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1359: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1360 = prims.slice_prim(t1358, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1360: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1361 = prims.slice_prim(t1358, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1361: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1362 = ltorch.reshape(t1359, 1, -1, 512, 128) # t1362: "cuda:0 bf16[1, 32, 512, 128]" - # t1362 = prims.reshape(t1359, (1, 32, 512, 128)) # t1362: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1363 = ltorch.reshape(t1360, 1, -1, 512, 128) # t1363: "cuda:0 bf16[1, 32, 512, 128]" - # t1363 = prims.reshape(t1360, (1, 32, 512, 128)) # t1363: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1364 = ltorch.reshape(t1361, 1, -1, 512, 128) # t1364: "cuda:0 bf16[1, 32, 512, 128]" - # t1364 = prims.reshape(t1361, (1, 32, 512, 128)) # t1364: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1365 = ltorch.getitem(t1362, (..., slice(None, 128, None))) # t1365: "cuda:0 bf16[1, 32, 512, 128]" - # t1365 = prims.slice_prim(t1362, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1365: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1366 = ltorch.getitem(t1365, (..., slice(None, 64, None))) # t1366: "cuda:0 bf16[1, 32, 512, 64]" - # t1366 = prims.slice_prim(t1365, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1366: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1367 = ltorch.getitem(t1365, (..., slice(64, None, None))) # t1367: "cuda:0 bf16[1, 32, 512, 64]" - # t1367 = prims.slice_prim(t1365, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1367: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1370 = ltorch.neg(t1367) # t1370: "cuda:0 bf16[1, 32, 512, 64]" - # t1368 = prims.convert_element_type(t1367, dtypes.float32) # t1368: "cuda:0 f32[1, 32, 512, 64]" - # t1369 = prims.neg(t1368) # t1369: "cuda:0 f32[1, 32, 512, 64]" - # t1370 = prims.convert_element_type(t1369, dtypes.bfloat16) # t1370: "cuda:0 bf16[1, 32, 512, 64]" - t1371 = ltorch.cat((t1370, t1366), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - # t1371 = prims.cat((t1370, t1366), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1374 = ltorch.mul(t1365, cos) # t1374: "cuda:0 f32[1, 32, 512, 128]" - # t1372 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1372: "cuda:0 f32[1, 32, 512, 128]" - # t1373 = prims.convert_element_type(t1365, dtypes.float32) # t1373: "cuda:0 f32[1, 32, 512, 128]" - # t1374 = prims.mul(t1373, t1372) # t1374: "cuda:0 f32[1, 32, 512, 128]" - t1377 = ltorch.mul(t1371, sin) # t1377: "cuda:0 f32[1, 32, 512, 128]" - # t1375 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1375: "cuda:0 f32[1, 32, 512, 128]" - # t1376 = prims.convert_element_type(t1371, dtypes.float32) # t1376: "cuda:0 f32[1, 32, 512, 128]" - # t1377 = prims.mul(t1376, t1375) # t1377: "cuda:0 f32[1, 32, 512, 128]" - t1378 = ltorch.add(t1374, t1377, alpha=None) # t1378: "cuda:0 f32[1, 32, 512, 128]" - # t1378 = prims.add(t1374, t1377) # t1378: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1379 = ltorch.to(t1378, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1379: "cuda:0 bf16[1, 32, 512, 128]" - # t1379 = prims.convert_element_type(t1378, dtypes.bfloat16) # t1379: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1380 = ltorch.getitem(t1363, (..., slice(None, 128, None))) # t1380: "cuda:0 bf16[1, 32, 512, 128]" - # t1380 = prims.slice_prim(t1363, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1380: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1381 = ltorch.getitem(t1380, (..., slice(None, 64, None))) # t1381: "cuda:0 bf16[1, 32, 512, 64]" - # t1381 = prims.slice_prim(t1380, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1381: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1382 = ltorch.getitem(t1380, (..., slice(64, None, None))) # t1382: "cuda:0 bf16[1, 32, 512, 64]" - # t1382 = prims.slice_prim(t1380, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1382: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1385 = ltorch.neg(t1382) # t1385: "cuda:0 bf16[1, 32, 512, 64]" - # t1383 = prims.convert_element_type(t1382, dtypes.float32) # t1383: "cuda:0 f32[1, 32, 512, 64]" - # t1384 = prims.neg(t1383) # t1384: "cuda:0 f32[1, 32, 512, 64]" - # t1385 = prims.convert_element_type(t1384, dtypes.bfloat16) # t1385: "cuda:0 bf16[1, 32, 512, 64]" - t1386 = ltorch.cat((t1385, t1381), -1) # t1386: "cuda:0 bf16[1, 32, 512, 128]" - # t1386 = prims.cat((t1385, t1381), -1) # t1386: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1389 = ltorch.mul(t1380, cos) # t1389: "cuda:0 f32[1, 32, 512, 128]" - # t1387 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1387: "cuda:0 f32[1, 32, 512, 128]" - # t1388 = prims.convert_element_type(t1380, dtypes.float32) # t1388: "cuda:0 f32[1, 32, 512, 128]" - # t1389 = prims.mul(t1388, t1387) # t1389: "cuda:0 f32[1, 32, 512, 128]" - t1392 = ltorch.mul(t1386, sin) # t1392: "cuda:0 f32[1, 32, 512, 128]" - # t1390 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1390: "cuda:0 f32[1, 32, 512, 128]" - # t1391 = prims.convert_element_type(t1386, dtypes.float32) # t1391: "cuda:0 f32[1, 32, 512, 128]" - # t1392 = prims.mul(t1391, t1390) # t1392: "cuda:0 f32[1, 32, 512, 128]" - t1393 = ltorch.add(t1389, t1392, alpha=None) # t1393: "cuda:0 f32[1, 32, 512, 128]" - # t1393 = prims.add(t1389, t1392) # t1393: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1394 = ltorch.to(t1393, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1394: "cuda:0 bf16[1, 32, 512, 128]" - # t1394 = prims.convert_element_type(t1393, dtypes.bfloat16) # t1394: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1395 = ltorch.getitem(t1362, (..., slice(128, None, None))) # t1395: "cuda:0 bf16[1, 32, 512, 0]" - # t1395 = prims.slice_prim(t1362, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1395: "cuda:0 bf16[1, 32, 512, 0]" - t1396 = ltorch.cat((t1379, t1395), -1) # t1396: "cuda:0 bf16[1, 32, 512, 128]" - # t1396 = prims.cat((t1379, t1395), -1) # t1396: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1397 = ltorch.getitem(t1363, (..., slice(128, None, None))) # t1397: "cuda:0 bf16[1, 32, 512, 0]" - # t1397 = prims.slice_prim(t1363, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1397: "cuda:0 bf16[1, 32, 512, 0]" - t1398 = ltorch.cat((t1394, t1397), -1) # t1398: "cuda:0 bf16[1, 32, 512, 128]" - # t1398 = prims.cat((t1394, t1397), -1) # t1398: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1428 = ltorch.scaled_dot_product_attention(t1396, t1398, t1364, None, 0.0, True, scale=0.08838834764831843) # t1428: "cuda:0 bf16[1, 32, 512, 128]" - # t1401 = ltorch.mul(t1396, 0.29730177875068026) # t1401: "cuda:0 bf16[1, 32, 512, 128]" - # t1399 = prims.convert_element_type(t1396, dtypes.float32) # t1399: "cuda:0 f32[1, 32, 512, 128]" - # t1400 = prims.mul(t1399, 0.29730177875068026) # t1400: "cuda:0 f32[1, 32, 512, 128]" - # t1401 = prims.convert_element_type(t1400, dtypes.bfloat16) # t1401: "cuda:0 bf16[1, 32, 512, 128]" - # t1402 = ltorch.transpose(t1398, -2, -1) # t1402: "cuda:0 bf16[1, 32, 128, 512]" - # t1402 = prims.transpose(t1398, (0, 1, 3, 2)) # t1402: "cuda:0 bf16[1, 32, 128, 512]" - # t1405 = ltorch.mul(t1402, 0.29730177875068026) # t1405: "cuda:0 bf16[1, 32, 128, 512]" - # t1403 = prims.convert_element_type(t1402, dtypes.float32) # t1403: "cuda:0 f32[1, 32, 128, 512]" - # t1404 = prims.mul(t1403, 0.29730177875068026) # t1404: "cuda:0 f32[1, 32, 128, 512]" - # t1405 = prims.convert_element_type(t1404, dtypes.bfloat16) # t1405: "cuda:0 bf16[1, 32, 128, 512]" - # t1406 = ltorch.matmul(t1401, t1405) # t1406: "cuda:0 bf16[1, 32, 512, 512]" - # t1406 = prims.matmul(t1401, t1405) # t1406: "cuda:0 bf16[1, 32, 512, 512]" - # t1416 = ltorch.tril(t1406, 0, fill_value=-float('inf')) # t1416: "cuda:0 bf16[1, 32, 512, 512]" - # t1407 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1407: "cuda:0 i64[512]" - # t1407 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1407: "cuda:0 i64[512]" - # t1408 = ltorch.unsqueeze(t1407, -1) # t1408: "cuda:0 i64[512, 1]" - # t1408 = prims.broadcast_in_dim(t1407, [512, 1], [0]) # t1408: "cuda:0 i64[512, 1]" - # t1409 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1409: "cuda:0 i64[512]" - # t1409 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1409: "cuda:0 i64[512]" - # t1410 = ltorch.unsqueeze(t1409, -2) # t1410: "cuda:0 i64[1, 512]" - # t1410 = prims.broadcast_in_dim(t1409, [1, 512], [1]) # t1410: "cuda:0 i64[1, 512]" - # t1411 = ltorch.add(t1408, 0, alpha=None) # t1411: "cuda:0 i64[512, 1]" - # t1411 = prims.add(t1408, 0) # t1411: "cuda:0 i64[512, 1]" - # t1414 = ltorch.ge(t1411, t1410) # t1414: "cuda:0 b8[512, 512]" - # t1412 = prims.broadcast_in_dim(t1411, (512, 512), (0, 1)) # t1412: "cuda:0 i64[512, 512]" - # t1413 = prims.broadcast_in_dim(t1410, (512, 512), (0, 1)) # t1413: "cuda:0 i64[512, 512]" - # t1414 = prims.ge(t1412, t1413) # t1414: "cuda:0 b8[512, 512]" - # t1416 = ltorch.where(t1414, t1406, -float('inf')) # t1416: "cuda:0 bf16[1, 32, 512, 512]" - # t1415 = prims.broadcast_in_dim(t1414, (1, 32, 512, 512), (2, 3)) # t1415: "cuda:0 b8[1, 32, 512, 512]" - # t1416 = prims.where(t1415, t1406, -float('inf')) # t1416: "cuda:0 bf16[1, 32, 512, 512]" - # t1427 = ltorch._softmax(t1416, -1, dtype=None) # t1427: "cuda:0 bf16[1, 32, 512, 512]" - # t1417 = ltorch.to(t1416, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1417: "cuda:0 f32[1, 32, 512, 512]" - # t1417 = prims.convert_element_type(t1416, dtypes.float32) # t1417: "cuda:0 f32[1, 32, 512, 512]" - # t1419 = ltorch.amax(t1417, -1, True) # t1419: "cuda:0 f32[1, 32, 512, 1]" - # t1418 = prims.amax(t1417, (3,)) # t1418: "cuda:0 f32[1, 32, 512]" - # t1419 = prims.broadcast_in_dim(t1418, [1, 32, 512, 1], [0, 1, 2]) # t1419: "cuda:0 f32[1, 32, 512, 1]" - # t1421 = ltorch.sub(t1417, t1419, alpha=None) # t1421: "cuda:0 f32[1, 32, 512, 512]" - # t1420 = prims.broadcast_in_dim(t1419, (1, 32, 512, 512), (0, 1, 2, 3)) # t1420: "cuda:0 f32[1, 32, 512, 512]" - # t1421 = prims.sub(t1417, t1420) # t1421: "cuda:0 f32[1, 32, 512, 512]" - # t1422 = ltorch.exp(t1421) # t1422: "cuda:0 f32[1, 32, 512, 512]" - # t1422 = prims.exp(t1421) # t1422: "cuda:0 f32[1, 32, 512, 512]" - # t1424 = ltorch.sum(t1422, -1, True, dtype=None) # t1424: "cuda:0 f32[1, 32, 512, 1]" - # t1423 = prims.sum(t1422, (3,)) # t1423: "cuda:0 f32[1, 32, 512]" - # t1424 = prims.broadcast_in_dim(t1423, [1, 32, 512, 1], [0, 1, 2]) # t1424: "cuda:0 f32[1, 32, 512, 1]" - # t1426 = ltorch.true_divide(t1422, t1424) # t1426: "cuda:0 f32[1, 32, 512, 512]" - # t1425 = prims.broadcast_in_dim(t1424, (1, 32, 512, 512), (0, 1, 2, 3)) # t1425: "cuda:0 f32[1, 32, 512, 512]" - # t1426 = prims.div(t1422, t1425) # t1426: "cuda:0 f32[1, 32, 512, 512]" - # t1427 = ltorch.to(t1426, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1427: "cuda:0 bf16[1, 32, 512, 512]" - # t1427 = prims.convert_element_type(t1426, dtypes.bfloat16) # t1427: "cuda:0 bf16[1, 32, 512, 512]" - # t1428 = ltorch.matmul(t1427, t1364) # t1428: "cuda:0 bf16[1, 32, 512, 128]" - # t1428 = prims.matmul(t1427, t1364) # t1428: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1429 = ltorch.transpose(t1428, 1, 2) # t1429: "cuda:0 bf16[1, 512, 32, 128]" - # t1429 = prims.transpose(t1428, (0, 2, 1, 3)) # t1429: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1430 = ltorch.reshape(t1429, 1, 512, 4096) # t1430: "cuda:0 bf16[1, 512, 4096]" - # t1430 = prims.reshape(t1429, (1, 512, 4096)) # t1430: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1434 = ltorch.linear(t1430, t_transformer_h_8_attn_proj_weight, None) # t1434: "cuda:0 bf16[1, 512, 4096]" - # t1434 = prims.linear(t1430, t_transformer_h_8_attn_proj_weight, None) # t1434: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1438 = ltorch.add(t1434, t1328, alpha=None) # t1438: "cuda:0 bf16[1, 512, 4096]" - # t1435 = prims.convert_element_type(t1434, dtypes.float32) # t1435: "cuda:0 f32[1, 512, 4096]" - # t1436 = prims.convert_element_type(t1328, dtypes.float32) # t1436: "cuda:0 f32[1, 512, 4096]" - # t1437 = prims.add(t1435, t1436) # t1437: "cuda:0 f32[1, 512, 4096]" - # t1438 = prims.convert_element_type(t1437, dtypes.bfloat16) # t1438: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1439 = prims.convert_element_type(t1438, dtypes.float32) # t1439: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1440 = ltorch.mul(t1439, t1439) # t1440: "cuda:0 f32[1, 512, 4096]" - # t1440 = prims.mul(t1439, t1439) # t1440: "cuda:0 f32[1, 512, 4096]" - t1444 = ltorch.mean(t1440, -1, True, dtype=None) # t1444: "cuda:0 f32[1, 512, 1]" - # t1442 = prims.sum(t1440, (2,)) # t1442: "cuda:0 f32[1, 512]" - # t1443 = prims.broadcast_in_dim(t1442, [1, 512, 1], [0, 1]) # t1443: "cuda:0 f32[1, 512, 1]" - # t1444 = ltorch.true_divide(t1443, 4096) # t1444: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1444 = prims.div(t1443, 4096.0) # t1444: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1446 = ltorch.add(t1444, 1e-05, alpha=None) # t1446: "cuda:0 f32[1, 512, 1]" - # t1446 = prims.add(t1444, 1e-05) # t1446: "cuda:0 f32[1, 512, 1]" - t1447 = ltorch.rsqrt(t1446) # t1447: "cuda:0 f32[1, 512, 1]" - # t1447 = prims.rsqrt(t1446) # t1447: "cuda:0 f32[1, 512, 1]" - t1449 = ltorch.mul(t1439, t1447) # t1449: "cuda:0 f32[1, 512, 4096]" - # t1448 = prims.broadcast_in_dim(t1447, (1, 512, 4096), (0, 1, 2)) # t1448: "cuda:0 f32[1, 512, 4096]" - # t1449 = prims.mul(t1439, t1448) # t1449: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1450 = ltorch.to(t1449, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1450: "cuda:0 bf16[1, 512, 4096]" - # t1450 = prims.convert_element_type(t1449, dtypes.bfloat16) # t1450: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1460 = ltorch.mul(t1450, t_transformer_h_8_norm_2_weight) # t1460: "cuda:0 bf16[1, 512, 4096]" - # t1456 = prims.broadcast_in_dim(t_transformer_h_8_norm_2_weight, (1, 512, 4096), (2,)) # t1456: "cuda:0 bf16[1, 512, 4096]" - # t1457 = prims.convert_element_type(t1450, dtypes.float32) # t1457: "cuda:0 f32[1, 512, 4096]" - # t1458 = prims.convert_element_type(t1456, dtypes.float32) # t1458: "cuda:0 f32[1, 512, 4096]" - # t1459 = prims.mul(t1457, t1458) # t1459: "cuda:0 f32[1, 512, 4096]" - # t1460 = prims.convert_element_type(t1459, dtypes.bfloat16) # t1460: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1465 = ltorch.linear(t1460, t_transformer_h_8_mlp_fc_1_weight, None) # t1465: "cuda:0 bf16[1, 512, 11008]" - # t1465 = prims.linear(t1460, t_transformer_h_8_mlp_fc_1_weight, None) # t1465: "cuda:0 bf16[1, 512, 11008]" - t1469 = ltorch.linear(t1460, t_transformer_h_8_mlp_fc_2_weight, None) # t1469: "cuda:0 bf16[1, 512, 11008]" - # t1469 = prims.linear(t1460, t_transformer_h_8_mlp_fc_2_weight, None) # t1469: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1479 = ltorch.silu(t1465, False) # t1479: "cuda:0 bf16[1, 512, 11008]" - # t1470 = prims.convert_element_type(t1465, dtypes.float32) # t1470: "cuda:0 f32[1, 512, 11008]" - # t1471 = prims.neg(t1470) # t1471: "cuda:0 f32[1, 512, 11008]" - # t1472 = prims.exp(t1471) # t1472: "cuda:0 f32[1, 512, 11008]" - # t1473 = prims.add(1.0, t1472) # t1473: "cuda:0 f32[1, 512, 11008]" - # t1474 = prims.reciprocal(t1473) # t1474: "cuda:0 f32[1, 512, 11008]" - # t1475 = prims.convert_element_type(t1474, dtypes.bfloat16) # t1475: "cuda:0 bf16[1, 512, 11008]" - # t1476 = prims.convert_element_type(t1465, dtypes.float32) # t1476: "cuda:0 f32[1, 512, 11008]" - # t1477 = prims.convert_element_type(t1475, dtypes.float32) # t1477: "cuda:0 f32[1, 512, 11008]" - # t1478 = prims.mul(t1476, t1477) # t1478: "cuda:0 f32[1, 512, 11008]" - # t1479 = prims.convert_element_type(t1478, dtypes.bfloat16) # t1479: "cuda:0 bf16[1, 512, 11008]" - t1483 = ltorch.mul(t1479, t1469) # t1483: "cuda:0 bf16[1, 512, 11008]" - # t1480 = prims.convert_element_type(t1479, dtypes.float32) # t1480: "cuda:0 f32[1, 512, 11008]" - # t1481 = prims.convert_element_type(t1469, dtypes.float32) # t1481: "cuda:0 f32[1, 512, 11008]" - # t1482 = prims.mul(t1480, t1481) # t1482: "cuda:0 f32[1, 512, 11008]" - # t1483 = prims.convert_element_type(t1482, dtypes.bfloat16) # t1483: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1487 = ltorch.linear(t1483, t_transformer_h_8_mlp_proj_weight, None) # t1487: "cuda:0 bf16[1, 512, 4096]" - # t1487 = prims.linear(t1483, t_transformer_h_8_mlp_proj_weight, None) # t1487: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1491 = ltorch.add(t1487, t1438, alpha=None) # t1491: "cuda:0 bf16[1, 512, 4096]" - # t1488 = prims.convert_element_type(t1487, dtypes.float32) # t1488: "cuda:0 f32[1, 512, 4096]" - # t1489 = prims.convert_element_type(t1438, dtypes.float32) # t1489: "cuda:0 f32[1, 512, 4096]" - # t1490 = prims.add(t1488, t1489) # t1490: "cuda:0 f32[1, 512, 4096]" - # t1491 = prims.convert_element_type(t1490, dtypes.bfloat16) # t1491: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1493 = prims.convert_element_type(t1491, dtypes.float32) # t1493: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1494 = ltorch.mul(t1493, t1493) # t1494: "cuda:0 f32[1, 512, 4096]" - # t1494 = prims.mul(t1493, t1493) # t1494: "cuda:0 f32[1, 512, 4096]" - t1498 = ltorch.mean(t1494, -1, True, dtype=None) # t1498: "cuda:0 f32[1, 512, 1]" - # t1496 = prims.sum(t1494, (2,)) # t1496: "cuda:0 f32[1, 512]" - # t1497 = prims.broadcast_in_dim(t1496, [1, 512, 1], [0, 1]) # t1497: "cuda:0 f32[1, 512, 1]" - # t1498 = ltorch.true_divide(t1497, 4096) # t1498: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1498 = prims.div(t1497, 4096.0) # t1498: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1500 = ltorch.add(t1498, 1e-05, alpha=None) # t1500: "cuda:0 f32[1, 512, 1]" - # t1500 = prims.add(t1498, 1e-05) # t1500: "cuda:0 f32[1, 512, 1]" - t1501 = ltorch.rsqrt(t1500) # t1501: "cuda:0 f32[1, 512, 1]" - # t1501 = prims.rsqrt(t1500) # t1501: "cuda:0 f32[1, 512, 1]" - t1503 = ltorch.mul(t1493, t1501) # t1503: "cuda:0 f32[1, 512, 4096]" - # t1502 = prims.broadcast_in_dim(t1501, (1, 512, 4096), (0, 1, 2)) # t1502: "cuda:0 f32[1, 512, 4096]" - # t1503 = prims.mul(t1493, t1502) # t1503: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1504 = ltorch.to(t1503, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1504: "cuda:0 bf16[1, 512, 4096]" - # t1504 = prims.convert_element_type(t1503, dtypes.bfloat16) # t1504: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1514 = ltorch.mul(t1504, t_transformer_h_9_norm_1_weight) # t1514: "cuda:0 bf16[1, 512, 4096]" - # t1510 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_weight, (1, 512, 4096), (2,)) # t1510: "cuda:0 bf16[1, 512, 4096]" - # t1511 = prims.convert_element_type(t1504, dtypes.float32) # t1511: "cuda:0 f32[1, 512, 4096]" - # t1512 = prims.convert_element_type(t1510, dtypes.float32) # t1512: "cuda:0 f32[1, 512, 4096]" - # t1513 = prims.mul(t1511, t1512) # t1513: "cuda:0 f32[1, 512, 4096]" - # t1514 = prims.convert_element_type(t1513, dtypes.bfloat16) # t1514: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1519 = ltorch.linear(t1514, t_transformer_h_9_attn_attn_weight, None) # t1519: "cuda:0 bf16[1, 512, 12288]" - # t1519 = prims.linear(t1514, t_transformer_h_9_attn_attn_weight, None) # t1519: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1520 = ltorch.view(t1519, 1, 512, 32, 3, 128) # t1520: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1520 = ltorch.reshape(t1519, (1, 512, 32, 3, 128)) # t1520: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1520 = prims.reshape(t1519, (1, 512, 32, 3, 128)) # t1520: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1521 = ltorch.permute(t1520, 0, 2, 3, 1, 4) # t1521: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1521 = prims.transpose(t1520, (0, 2, 3, 1, 4)) # t1521: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1522, t1523, t1524) = ltorch.split(t1521, (1, 1, 1), 2) - # t1522 = prims.slice_prim(t1521, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1522: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1523 = prims.slice_prim(t1521, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1523: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1524 = prims.slice_prim(t1521, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1524: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1525 = ltorch.reshape(t1522, 1, -1, 512, 128) # t1525: "cuda:0 bf16[1, 32, 512, 128]" - # t1525 = prims.reshape(t1522, (1, 32, 512, 128)) # t1525: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1526 = ltorch.reshape(t1523, 1, -1, 512, 128) # t1526: "cuda:0 bf16[1, 32, 512, 128]" - # t1526 = prims.reshape(t1523, (1, 32, 512, 128)) # t1526: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1527 = ltorch.reshape(t1524, 1, -1, 512, 128) # t1527: "cuda:0 bf16[1, 32, 512, 128]" - # t1527 = prims.reshape(t1524, (1, 32, 512, 128)) # t1527: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1528 = ltorch.getitem(t1525, (..., slice(None, 128, None))) # t1528: "cuda:0 bf16[1, 32, 512, 128]" - # t1528 = prims.slice_prim(t1525, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1528: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1529 = ltorch.getitem(t1528, (..., slice(None, 64, None))) # t1529: "cuda:0 bf16[1, 32, 512, 64]" - # t1529 = prims.slice_prim(t1528, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1529: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1530 = ltorch.getitem(t1528, (..., slice(64, None, None))) # t1530: "cuda:0 bf16[1, 32, 512, 64]" - # t1530 = prims.slice_prim(t1528, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1530: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1533 = ltorch.neg(t1530) # t1533: "cuda:0 bf16[1, 32, 512, 64]" - # t1531 = prims.convert_element_type(t1530, dtypes.float32) # t1531: "cuda:0 f32[1, 32, 512, 64]" - # t1532 = prims.neg(t1531) # t1532: "cuda:0 f32[1, 32, 512, 64]" - # t1533 = prims.convert_element_type(t1532, dtypes.bfloat16) # t1533: "cuda:0 bf16[1, 32, 512, 64]" - t1534 = ltorch.cat((t1533, t1529), -1) # t1534: "cuda:0 bf16[1, 32, 512, 128]" - # t1534 = prims.cat((t1533, t1529), -1) # t1534: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1537 = ltorch.mul(t1528, cos) # t1537: "cuda:0 f32[1, 32, 512, 128]" - # t1535 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1535: "cuda:0 f32[1, 32, 512, 128]" - # t1536 = prims.convert_element_type(t1528, dtypes.float32) # t1536: "cuda:0 f32[1, 32, 512, 128]" - # t1537 = prims.mul(t1536, t1535) # t1537: "cuda:0 f32[1, 32, 512, 128]" - t1540 = ltorch.mul(t1534, sin) # t1540: "cuda:0 f32[1, 32, 512, 128]" - # t1538 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1538: "cuda:0 f32[1, 32, 512, 128]" - # t1539 = prims.convert_element_type(t1534, dtypes.float32) # t1539: "cuda:0 f32[1, 32, 512, 128]" - # t1540 = prims.mul(t1539, t1538) # t1540: "cuda:0 f32[1, 32, 512, 128]" - t1541 = ltorch.add(t1537, t1540, alpha=None) # t1541: "cuda:0 f32[1, 32, 512, 128]" - # t1541 = prims.add(t1537, t1540) # t1541: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1542 = ltorch.to(t1541, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1542: "cuda:0 bf16[1, 32, 512, 128]" - # t1542 = prims.convert_element_type(t1541, dtypes.bfloat16) # t1542: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1543 = ltorch.getitem(t1526, (..., slice(None, 128, None))) # t1543: "cuda:0 bf16[1, 32, 512, 128]" - # t1543 = prims.slice_prim(t1526, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1543: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1544 = ltorch.getitem(t1543, (..., slice(None, 64, None))) # t1544: "cuda:0 bf16[1, 32, 512, 64]" - # t1544 = prims.slice_prim(t1543, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1544: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1545 = ltorch.getitem(t1543, (..., slice(64, None, None))) # t1545: "cuda:0 bf16[1, 32, 512, 64]" - # t1545 = prims.slice_prim(t1543, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1545: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1548 = ltorch.neg(t1545) # t1548: "cuda:0 bf16[1, 32, 512, 64]" - # t1546 = prims.convert_element_type(t1545, dtypes.float32) # t1546: "cuda:0 f32[1, 32, 512, 64]" - # t1547 = prims.neg(t1546) # t1547: "cuda:0 f32[1, 32, 512, 64]" - # t1548 = prims.convert_element_type(t1547, dtypes.bfloat16) # t1548: "cuda:0 bf16[1, 32, 512, 64]" - t1549 = ltorch.cat((t1548, t1544), -1) # t1549: "cuda:0 bf16[1, 32, 512, 128]" - # t1549 = prims.cat((t1548, t1544), -1) # t1549: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1552 = ltorch.mul(t1543, cos) # t1552: "cuda:0 f32[1, 32, 512, 128]" - # t1550 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1550: "cuda:0 f32[1, 32, 512, 128]" - # t1551 = prims.convert_element_type(t1543, dtypes.float32) # t1551: "cuda:0 f32[1, 32, 512, 128]" - # t1552 = prims.mul(t1551, t1550) # t1552: "cuda:0 f32[1, 32, 512, 128]" - t1555 = ltorch.mul(t1549, sin) # t1555: "cuda:0 f32[1, 32, 512, 128]" - # t1553 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1553: "cuda:0 f32[1, 32, 512, 128]" - # t1554 = prims.convert_element_type(t1549, dtypes.float32) # t1554: "cuda:0 f32[1, 32, 512, 128]" - # t1555 = prims.mul(t1554, t1553) # t1555: "cuda:0 f32[1, 32, 512, 128]" - t1556 = ltorch.add(t1552, t1555, alpha=None) # t1556: "cuda:0 f32[1, 32, 512, 128]" - # t1556 = prims.add(t1552, t1555) # t1556: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1557 = ltorch.to(t1556, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1557: "cuda:0 bf16[1, 32, 512, 128]" - # t1557 = prims.convert_element_type(t1556, dtypes.bfloat16) # t1557: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1558 = ltorch.getitem(t1525, (..., slice(128, None, None))) # t1558: "cuda:0 bf16[1, 32, 512, 0]" - # t1558 = prims.slice_prim(t1525, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1558: "cuda:0 bf16[1, 32, 512, 0]" - t1559 = ltorch.cat((t1542, t1558), -1) # t1559: "cuda:0 bf16[1, 32, 512, 128]" - # t1559 = prims.cat((t1542, t1558), -1) # t1559: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1560 = ltorch.getitem(t1526, (..., slice(128, None, None))) # t1560: "cuda:0 bf16[1, 32, 512, 0]" - # t1560 = prims.slice_prim(t1526, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1560: "cuda:0 bf16[1, 32, 512, 0]" - t1561 = ltorch.cat((t1557, t1560), -1) # t1561: "cuda:0 bf16[1, 32, 512, 128]" - # t1561 = prims.cat((t1557, t1560), -1) # t1561: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1591 = ltorch.scaled_dot_product_attention(t1559, t1561, t1527, None, 0.0, True, scale=0.08838834764831843) # t1591: "cuda:0 bf16[1, 32, 512, 128]" - # t1564 = ltorch.mul(t1559, 0.29730177875068026) # t1564: "cuda:0 bf16[1, 32, 512, 128]" - # t1562 = prims.convert_element_type(t1559, dtypes.float32) # t1562: "cuda:0 f32[1, 32, 512, 128]" - # t1563 = prims.mul(t1562, 0.29730177875068026) # t1563: "cuda:0 f32[1, 32, 512, 128]" - # t1564 = prims.convert_element_type(t1563, dtypes.bfloat16) # t1564: "cuda:0 bf16[1, 32, 512, 128]" - # t1565 = ltorch.transpose(t1561, -2, -1) # t1565: "cuda:0 bf16[1, 32, 128, 512]" - # t1565 = prims.transpose(t1561, (0, 1, 3, 2)) # t1565: "cuda:0 bf16[1, 32, 128, 512]" - # t1568 = ltorch.mul(t1565, 0.29730177875068026) # t1568: "cuda:0 bf16[1, 32, 128, 512]" - # t1566 = prims.convert_element_type(t1565, dtypes.float32) # t1566: "cuda:0 f32[1, 32, 128, 512]" - # t1567 = prims.mul(t1566, 0.29730177875068026) # t1567: "cuda:0 f32[1, 32, 128, 512]" - # t1568 = prims.convert_element_type(t1567, dtypes.bfloat16) # t1568: "cuda:0 bf16[1, 32, 128, 512]" - # t1569 = ltorch.matmul(t1564, t1568) # t1569: "cuda:0 bf16[1, 32, 512, 512]" - # t1569 = prims.matmul(t1564, t1568) # t1569: "cuda:0 bf16[1, 32, 512, 512]" - # t1579 = ltorch.tril(t1569, 0, fill_value=-float('inf')) # t1579: "cuda:0 bf16[1, 32, 512, 512]" - # t1570 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1570: "cuda:0 i64[512]" - # t1570 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1570: "cuda:0 i64[512]" - # t1571 = ltorch.unsqueeze(t1570, -1) # t1571: "cuda:0 i64[512, 1]" - # t1571 = prims.broadcast_in_dim(t1570, [512, 1], [0]) # t1571: "cuda:0 i64[512, 1]" - # t1572 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1572: "cuda:0 i64[512]" - # t1572 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1572: "cuda:0 i64[512]" - # t1573 = ltorch.unsqueeze(t1572, -2) # t1573: "cuda:0 i64[1, 512]" - # t1573 = prims.broadcast_in_dim(t1572, [1, 512], [1]) # t1573: "cuda:0 i64[1, 512]" - # t1574 = ltorch.add(t1571, 0, alpha=None) # t1574: "cuda:0 i64[512, 1]" - # t1574 = prims.add(t1571, 0) # t1574: "cuda:0 i64[512, 1]" - # t1577 = ltorch.ge(t1574, t1573) # t1577: "cuda:0 b8[512, 512]" - # t1575 = prims.broadcast_in_dim(t1574, (512, 512), (0, 1)) # t1575: "cuda:0 i64[512, 512]" - # t1576 = prims.broadcast_in_dim(t1573, (512, 512), (0, 1)) # t1576: "cuda:0 i64[512, 512]" - # t1577 = prims.ge(t1575, t1576) # t1577: "cuda:0 b8[512, 512]" - # t1579 = ltorch.where(t1577, t1569, -float('inf')) # t1579: "cuda:0 bf16[1, 32, 512, 512]" - # t1578 = prims.broadcast_in_dim(t1577, (1, 32, 512, 512), (2, 3)) # t1578: "cuda:0 b8[1, 32, 512, 512]" - # t1579 = prims.where(t1578, t1569, -float('inf')) # t1579: "cuda:0 bf16[1, 32, 512, 512]" - # t1590 = ltorch._softmax(t1579, -1, dtype=None) # t1590: "cuda:0 bf16[1, 32, 512, 512]" - # t1580 = ltorch.to(t1579, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1580: "cuda:0 f32[1, 32, 512, 512]" - # t1580 = prims.convert_element_type(t1579, dtypes.float32) # t1580: "cuda:0 f32[1, 32, 512, 512]" - # t1582 = ltorch.amax(t1580, -1, True) # t1582: "cuda:0 f32[1, 32, 512, 1]" - # t1581 = prims.amax(t1580, (3,)) # t1581: "cuda:0 f32[1, 32, 512]" - # t1582 = prims.broadcast_in_dim(t1581, [1, 32, 512, 1], [0, 1, 2]) # t1582: "cuda:0 f32[1, 32, 512, 1]" - # t1584 = ltorch.sub(t1580, t1582, alpha=None) # t1584: "cuda:0 f32[1, 32, 512, 512]" - # t1583 = prims.broadcast_in_dim(t1582, (1, 32, 512, 512), (0, 1, 2, 3)) # t1583: "cuda:0 f32[1, 32, 512, 512]" - # t1584 = prims.sub(t1580, t1583) # t1584: "cuda:0 f32[1, 32, 512, 512]" - # t1585 = ltorch.exp(t1584) # t1585: "cuda:0 f32[1, 32, 512, 512]" - # t1585 = prims.exp(t1584) # t1585: "cuda:0 f32[1, 32, 512, 512]" - # t1587 = ltorch.sum(t1585, -1, True, dtype=None) # t1587: "cuda:0 f32[1, 32, 512, 1]" - # t1586 = prims.sum(t1585, (3,)) # t1586: "cuda:0 f32[1, 32, 512]" - # t1587 = prims.broadcast_in_dim(t1586, [1, 32, 512, 1], [0, 1, 2]) # t1587: "cuda:0 f32[1, 32, 512, 1]" - # t1589 = ltorch.true_divide(t1585, t1587) # t1589: "cuda:0 f32[1, 32, 512, 512]" - # t1588 = prims.broadcast_in_dim(t1587, (1, 32, 512, 512), (0, 1, 2, 3)) # t1588: "cuda:0 f32[1, 32, 512, 512]" - # t1589 = prims.div(t1585, t1588) # t1589: "cuda:0 f32[1, 32, 512, 512]" - # t1590 = ltorch.to(t1589, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1590: "cuda:0 bf16[1, 32, 512, 512]" - # t1590 = prims.convert_element_type(t1589, dtypes.bfloat16) # t1590: "cuda:0 bf16[1, 32, 512, 512]" - # t1591 = ltorch.matmul(t1590, t1527) # t1591: "cuda:0 bf16[1, 32, 512, 128]" - # t1591 = prims.matmul(t1590, t1527) # t1591: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1592 = ltorch.transpose(t1591, 1, 2) # t1592: "cuda:0 bf16[1, 512, 32, 128]" - # t1592 = prims.transpose(t1591, (0, 2, 1, 3)) # t1592: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1593 = ltorch.reshape(t1592, 1, 512, 4096) # t1593: "cuda:0 bf16[1, 512, 4096]" - # t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1597 = ltorch.linear(t1593, t_transformer_h_9_attn_proj_weight, None) # t1597: "cuda:0 bf16[1, 512, 4096]" - # t1597 = prims.linear(t1593, t_transformer_h_9_attn_proj_weight, None) # t1597: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1601 = ltorch.add(t1597, t1491, alpha=None) # t1601: "cuda:0 bf16[1, 512, 4096]" - # t1598 = prims.convert_element_type(t1597, dtypes.float32) # t1598: "cuda:0 f32[1, 512, 4096]" - # t1599 = prims.convert_element_type(t1491, dtypes.float32) # t1599: "cuda:0 f32[1, 512, 4096]" - # t1600 = prims.add(t1598, t1599) # t1600: "cuda:0 f32[1, 512, 4096]" - # t1601 = prims.convert_element_type(t1600, dtypes.bfloat16) # t1601: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1602 = prims.convert_element_type(t1601, dtypes.float32) # t1602: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1603 = ltorch.mul(t1602, t1602) # t1603: "cuda:0 f32[1, 512, 4096]" - # t1603 = prims.mul(t1602, t1602) # t1603: "cuda:0 f32[1, 512, 4096]" - t1607 = ltorch.mean(t1603, -1, True, dtype=None) # t1607: "cuda:0 f32[1, 512, 1]" - # t1605 = prims.sum(t1603, (2,)) # t1605: "cuda:0 f32[1, 512]" - # t1606 = prims.broadcast_in_dim(t1605, [1, 512, 1], [0, 1]) # t1606: "cuda:0 f32[1, 512, 1]" - # t1607 = ltorch.true_divide(t1606, 4096) # t1607: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1607 = prims.div(t1606, 4096.0) # t1607: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1609 = ltorch.add(t1607, 1e-05, alpha=None) # t1609: "cuda:0 f32[1, 512, 1]" - # t1609 = prims.add(t1607, 1e-05) # t1609: "cuda:0 f32[1, 512, 1]" - t1610 = ltorch.rsqrt(t1609) # t1610: "cuda:0 f32[1, 512, 1]" - # t1610 = prims.rsqrt(t1609) # t1610: "cuda:0 f32[1, 512, 1]" - t1612 = ltorch.mul(t1602, t1610) # t1612: "cuda:0 f32[1, 512, 4096]" - # t1611 = prims.broadcast_in_dim(t1610, (1, 512, 4096), (0, 1, 2)) # t1611: "cuda:0 f32[1, 512, 4096]" - # t1612 = prims.mul(t1602, t1611) # t1612: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1613 = ltorch.to(t1612, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1613: "cuda:0 bf16[1, 512, 4096]" - # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1623 = ltorch.mul(t1613, t_transformer_h_9_norm_2_weight) # t1623: "cuda:0 bf16[1, 512, 4096]" - # t1619 = prims.broadcast_in_dim(t_transformer_h_9_norm_2_weight, (1, 512, 4096), (2,)) # t1619: "cuda:0 bf16[1, 512, 4096]" - # t1620 = prims.convert_element_type(t1613, dtypes.float32) # t1620: "cuda:0 f32[1, 512, 4096]" - # t1621 = prims.convert_element_type(t1619, dtypes.float32) # t1621: "cuda:0 f32[1, 512, 4096]" - # t1622 = prims.mul(t1620, t1621) # t1622: "cuda:0 f32[1, 512, 4096]" - # t1623 = prims.convert_element_type(t1622, dtypes.bfloat16) # t1623: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1628 = ltorch.linear(t1623, t_transformer_h_9_mlp_fc_1_weight, None) # t1628: "cuda:0 bf16[1, 512, 11008]" - # t1628 = prims.linear(t1623, t_transformer_h_9_mlp_fc_1_weight, None) # t1628: "cuda:0 bf16[1, 512, 11008]" - t1632 = ltorch.linear(t1623, t_transformer_h_9_mlp_fc_2_weight, None) # t1632: "cuda:0 bf16[1, 512, 11008]" - # t1632 = prims.linear(t1623, t_transformer_h_9_mlp_fc_2_weight, None) # t1632: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1642 = ltorch.silu(t1628, False) # t1642: "cuda:0 bf16[1, 512, 11008]" - # t1633 = prims.convert_element_type(t1628, dtypes.float32) # t1633: "cuda:0 f32[1, 512, 11008]" - # t1634 = prims.neg(t1633) # t1634: "cuda:0 f32[1, 512, 11008]" - # t1635 = prims.exp(t1634) # t1635: "cuda:0 f32[1, 512, 11008]" - # t1636 = prims.add(1.0, t1635) # t1636: "cuda:0 f32[1, 512, 11008]" - # t1637 = prims.reciprocal(t1636) # t1637: "cuda:0 f32[1, 512, 11008]" - # t1638 = prims.convert_element_type(t1637, dtypes.bfloat16) # t1638: "cuda:0 bf16[1, 512, 11008]" - # t1639 = prims.convert_element_type(t1628, dtypes.float32) # t1639: "cuda:0 f32[1, 512, 11008]" - # t1640 = prims.convert_element_type(t1638, dtypes.float32) # t1640: "cuda:0 f32[1, 512, 11008]" - # t1641 = prims.mul(t1639, t1640) # t1641: "cuda:0 f32[1, 512, 11008]" - # t1642 = prims.convert_element_type(t1641, dtypes.bfloat16) # t1642: "cuda:0 bf16[1, 512, 11008]" - t1646 = ltorch.mul(t1642, t1632) # t1646: "cuda:0 bf16[1, 512, 11008]" - # t1643 = prims.convert_element_type(t1642, dtypes.float32) # t1643: "cuda:0 f32[1, 512, 11008]" - # t1644 = prims.convert_element_type(t1632, dtypes.float32) # t1644: "cuda:0 f32[1, 512, 11008]" - # t1645 = prims.mul(t1643, t1644) # t1645: "cuda:0 f32[1, 512, 11008]" - # t1646 = prims.convert_element_type(t1645, dtypes.bfloat16) # t1646: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1650 = ltorch.linear(t1646, t_transformer_h_9_mlp_proj_weight, None) # t1650: "cuda:0 bf16[1, 512, 4096]" - # t1650 = prims.linear(t1646, t_transformer_h_9_mlp_proj_weight, None) # t1650: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1654 = ltorch.add(t1650, t1601, alpha=None) # t1654: "cuda:0 bf16[1, 512, 4096]" - # t1651 = prims.convert_element_type(t1650, dtypes.float32) # t1651: "cuda:0 f32[1, 512, 4096]" - # t1652 = prims.convert_element_type(t1601, dtypes.float32) # t1652: "cuda:0 f32[1, 512, 4096]" - # t1653 = prims.add(t1651, t1652) # t1653: "cuda:0 f32[1, 512, 4096]" - # t1654 = prims.convert_element_type(t1653, dtypes.bfloat16) # t1654: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1656 = prims.convert_element_type(t1654, dtypes.float32) # t1656: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1657 = ltorch.mul(t1656, t1656) # t1657: "cuda:0 f32[1, 512, 4096]" - # t1657 = prims.mul(t1656, t1656) # t1657: "cuda:0 f32[1, 512, 4096]" - t1661 = ltorch.mean(t1657, -1, True, dtype=None) # t1661: "cuda:0 f32[1, 512, 1]" - # t1659 = prims.sum(t1657, (2,)) # t1659: "cuda:0 f32[1, 512]" - # t1660 = prims.broadcast_in_dim(t1659, [1, 512, 1], [0, 1]) # t1660: "cuda:0 f32[1, 512, 1]" - # t1661 = ltorch.true_divide(t1660, 4096) # t1661: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1661 = prims.div(t1660, 4096.0) # t1661: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1663 = ltorch.add(t1661, 1e-05, alpha=None) # t1663: "cuda:0 f32[1, 512, 1]" - # t1663 = prims.add(t1661, 1e-05) # t1663: "cuda:0 f32[1, 512, 1]" - t1664 = ltorch.rsqrt(t1663) # t1664: "cuda:0 f32[1, 512, 1]" - # t1664 = prims.rsqrt(t1663) # t1664: "cuda:0 f32[1, 512, 1]" - t1666 = ltorch.mul(t1656, t1664) # t1666: "cuda:0 f32[1, 512, 4096]" - # t1665 = prims.broadcast_in_dim(t1664, (1, 512, 4096), (0, 1, 2)) # t1665: "cuda:0 f32[1, 512, 4096]" - # t1666 = prims.mul(t1656, t1665) # t1666: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1667 = ltorch.to(t1666, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1667: "cuda:0 bf16[1, 512, 4096]" - # t1667 = prims.convert_element_type(t1666, dtypes.bfloat16) # t1667: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1677 = ltorch.mul(t1667, t_transformer_h_10_norm_1_weight) # t1677: "cuda:0 bf16[1, 512, 4096]" - # t1673 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_weight, (1, 512, 4096), (2,)) # t1673: "cuda:0 bf16[1, 512, 4096]" - # t1674 = prims.convert_element_type(t1667, dtypes.float32) # t1674: "cuda:0 f32[1, 512, 4096]" - # t1675 = prims.convert_element_type(t1673, dtypes.float32) # t1675: "cuda:0 f32[1, 512, 4096]" - # t1676 = prims.mul(t1674, t1675) # t1676: "cuda:0 f32[1, 512, 4096]" - # t1677 = prims.convert_element_type(t1676, dtypes.bfloat16) # t1677: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1682 = ltorch.linear(t1677, t_transformer_h_10_attn_attn_weight, None) # t1682: "cuda:0 bf16[1, 512, 12288]" - # t1682 = prims.linear(t1677, t_transformer_h_10_attn_attn_weight, None) # t1682: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1683 = ltorch.view(t1682, 1, 512, 32, 3, 128) # t1683: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1683 = ltorch.reshape(t1682, (1, 512, 32, 3, 128)) # t1683: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1683 = prims.reshape(t1682, (1, 512, 32, 3, 128)) # t1683: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1684 = ltorch.permute(t1683, 0, 2, 3, 1, 4) # t1684: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1684 = prims.transpose(t1683, (0, 2, 3, 1, 4)) # t1684: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1685, t1686, t1687) = ltorch.split(t1684, (1, 1, 1), 2) - # t1685 = prims.slice_prim(t1684, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1685: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1686 = prims.slice_prim(t1684, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1686: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1687 = prims.slice_prim(t1684, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1687: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1688 = ltorch.reshape(t1685, 1, -1, 512, 128) # t1688: "cuda:0 bf16[1, 32, 512, 128]" - # t1688 = prims.reshape(t1685, (1, 32, 512, 128)) # t1688: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1689 = ltorch.reshape(t1686, 1, -1, 512, 128) # t1689: "cuda:0 bf16[1, 32, 512, 128]" - # t1689 = prims.reshape(t1686, (1, 32, 512, 128)) # t1689: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1690 = ltorch.reshape(t1687, 1, -1, 512, 128) # t1690: "cuda:0 bf16[1, 32, 512, 128]" - # t1690 = prims.reshape(t1687, (1, 32, 512, 128)) # t1690: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1691 = ltorch.getitem(t1688, (..., slice(None, 128, None))) # t1691: "cuda:0 bf16[1, 32, 512, 128]" - # t1691 = prims.slice_prim(t1688, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1691: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1692 = ltorch.getitem(t1691, (..., slice(None, 64, None))) # t1692: "cuda:0 bf16[1, 32, 512, 64]" - # t1692 = prims.slice_prim(t1691, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1692: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1693 = ltorch.getitem(t1691, (..., slice(64, None, None))) # t1693: "cuda:0 bf16[1, 32, 512, 64]" - # t1693 = prims.slice_prim(t1691, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1693: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1696 = ltorch.neg(t1693) # t1696: "cuda:0 bf16[1, 32, 512, 64]" - # t1694 = prims.convert_element_type(t1693, dtypes.float32) # t1694: "cuda:0 f32[1, 32, 512, 64]" - # t1695 = prims.neg(t1694) # t1695: "cuda:0 f32[1, 32, 512, 64]" - # t1696 = prims.convert_element_type(t1695, dtypes.bfloat16) # t1696: "cuda:0 bf16[1, 32, 512, 64]" - t1697 = ltorch.cat((t1696, t1692), -1) # t1697: "cuda:0 bf16[1, 32, 512, 128]" - # t1697 = prims.cat((t1696, t1692), -1) # t1697: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1700 = ltorch.mul(t1691, cos) # t1700: "cuda:0 f32[1, 32, 512, 128]" - # t1698 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1698: "cuda:0 f32[1, 32, 512, 128]" - # t1699 = prims.convert_element_type(t1691, dtypes.float32) # t1699: "cuda:0 f32[1, 32, 512, 128]" - # t1700 = prims.mul(t1699, t1698) # t1700: "cuda:0 f32[1, 32, 512, 128]" - t1703 = ltorch.mul(t1697, sin) # t1703: "cuda:0 f32[1, 32, 512, 128]" - # t1701 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1701: "cuda:0 f32[1, 32, 512, 128]" - # t1702 = prims.convert_element_type(t1697, dtypes.float32) # t1702: "cuda:0 f32[1, 32, 512, 128]" - # t1703 = prims.mul(t1702, t1701) # t1703: "cuda:0 f32[1, 32, 512, 128]" - t1704 = ltorch.add(t1700, t1703, alpha=None) # t1704: "cuda:0 f32[1, 32, 512, 128]" - # t1704 = prims.add(t1700, t1703) # t1704: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1705 = ltorch.to(t1704, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1705: "cuda:0 bf16[1, 32, 512, 128]" - # t1705 = prims.convert_element_type(t1704, dtypes.bfloat16) # t1705: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1706 = ltorch.getitem(t1689, (..., slice(None, 128, None))) # t1706: "cuda:0 bf16[1, 32, 512, 128]" - # t1706 = prims.slice_prim(t1689, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1706: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1707 = ltorch.getitem(t1706, (..., slice(None, 64, None))) # t1707: "cuda:0 bf16[1, 32, 512, 64]" - # t1707 = prims.slice_prim(t1706, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1707: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1708 = ltorch.getitem(t1706, (..., slice(64, None, None))) # t1708: "cuda:0 bf16[1, 32, 512, 64]" - # t1708 = prims.slice_prim(t1706, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1708: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1711 = ltorch.neg(t1708) # t1711: "cuda:0 bf16[1, 32, 512, 64]" - # t1709 = prims.convert_element_type(t1708, dtypes.float32) # t1709: "cuda:0 f32[1, 32, 512, 64]" - # t1710 = prims.neg(t1709) # t1710: "cuda:0 f32[1, 32, 512, 64]" - # t1711 = prims.convert_element_type(t1710, dtypes.bfloat16) # t1711: "cuda:0 bf16[1, 32, 512, 64]" - t1712 = ltorch.cat((t1711, t1707), -1) # t1712: "cuda:0 bf16[1, 32, 512, 128]" - # t1712 = prims.cat((t1711, t1707), -1) # t1712: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1715 = ltorch.mul(t1706, cos) # t1715: "cuda:0 f32[1, 32, 512, 128]" - # t1713 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1713: "cuda:0 f32[1, 32, 512, 128]" - # t1714 = prims.convert_element_type(t1706, dtypes.float32) # t1714: "cuda:0 f32[1, 32, 512, 128]" - # t1715 = prims.mul(t1714, t1713) # t1715: "cuda:0 f32[1, 32, 512, 128]" - t1718 = ltorch.mul(t1712, sin) # t1718: "cuda:0 f32[1, 32, 512, 128]" - # t1716 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1716: "cuda:0 f32[1, 32, 512, 128]" - # t1717 = prims.convert_element_type(t1712, dtypes.float32) # t1717: "cuda:0 f32[1, 32, 512, 128]" - # t1718 = prims.mul(t1717, t1716) # t1718: "cuda:0 f32[1, 32, 512, 128]" - t1719 = ltorch.add(t1715, t1718, alpha=None) # t1719: "cuda:0 f32[1, 32, 512, 128]" - # t1719 = prims.add(t1715, t1718) # t1719: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1720 = ltorch.to(t1719, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1720: "cuda:0 bf16[1, 32, 512, 128]" - # t1720 = prims.convert_element_type(t1719, dtypes.bfloat16) # t1720: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1721 = ltorch.getitem(t1688, (..., slice(128, None, None))) # t1721: "cuda:0 bf16[1, 32, 512, 0]" - # t1721 = prims.slice_prim(t1688, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1721: "cuda:0 bf16[1, 32, 512, 0]" - t1722 = ltorch.cat((t1705, t1721), -1) # t1722: "cuda:0 bf16[1, 32, 512, 128]" - # t1722 = prims.cat((t1705, t1721), -1) # t1722: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1723 = ltorch.getitem(t1689, (..., slice(128, None, None))) # t1723: "cuda:0 bf16[1, 32, 512, 0]" - # t1723 = prims.slice_prim(t1689, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1723: "cuda:0 bf16[1, 32, 512, 0]" - t1724 = ltorch.cat((t1720, t1723), -1) # t1724: "cuda:0 bf16[1, 32, 512, 128]" - # t1724 = prims.cat((t1720, t1723), -1) # t1724: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1754 = ltorch.scaled_dot_product_attention(t1722, t1724, t1690, None, 0.0, True, scale=0.08838834764831843) # t1754: "cuda:0 bf16[1, 32, 512, 128]" - # t1727 = ltorch.mul(t1722, 0.29730177875068026) # t1727: "cuda:0 bf16[1, 32, 512, 128]" - # t1725 = prims.convert_element_type(t1722, dtypes.float32) # t1725: "cuda:0 f32[1, 32, 512, 128]" - # t1726 = prims.mul(t1725, 0.29730177875068026) # t1726: "cuda:0 f32[1, 32, 512, 128]" - # t1727 = prims.convert_element_type(t1726, dtypes.bfloat16) # t1727: "cuda:0 bf16[1, 32, 512, 128]" - # t1728 = ltorch.transpose(t1724, -2, -1) # t1728: "cuda:0 bf16[1, 32, 128, 512]" - # t1728 = prims.transpose(t1724, (0, 1, 3, 2)) # t1728: "cuda:0 bf16[1, 32, 128, 512]" - # t1731 = ltorch.mul(t1728, 0.29730177875068026) # t1731: "cuda:0 bf16[1, 32, 128, 512]" - # t1729 = prims.convert_element_type(t1728, dtypes.float32) # t1729: "cuda:0 f32[1, 32, 128, 512]" - # t1730 = prims.mul(t1729, 0.29730177875068026) # t1730: "cuda:0 f32[1, 32, 128, 512]" - # t1731 = prims.convert_element_type(t1730, dtypes.bfloat16) # t1731: "cuda:0 bf16[1, 32, 128, 512]" - # t1732 = ltorch.matmul(t1727, t1731) # t1732: "cuda:0 bf16[1, 32, 512, 512]" - # t1732 = prims.matmul(t1727, t1731) # t1732: "cuda:0 bf16[1, 32, 512, 512]" - # t1742 = ltorch.tril(t1732, 0, fill_value=-float('inf')) # t1742: "cuda:0 bf16[1, 32, 512, 512]" - # t1733 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1733: "cuda:0 i64[512]" - # t1733 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1733: "cuda:0 i64[512]" - # t1734 = ltorch.unsqueeze(t1733, -1) # t1734: "cuda:0 i64[512, 1]" - # t1734 = prims.broadcast_in_dim(t1733, [512, 1], [0]) # t1734: "cuda:0 i64[512, 1]" - # t1735 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1735: "cuda:0 i64[512]" - # t1735 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1735: "cuda:0 i64[512]" - # t1736 = ltorch.unsqueeze(t1735, -2) # t1736: "cuda:0 i64[1, 512]" - # t1736 = prims.broadcast_in_dim(t1735, [1, 512], [1]) # t1736: "cuda:0 i64[1, 512]" - # t1737 = ltorch.add(t1734, 0, alpha=None) # t1737: "cuda:0 i64[512, 1]" - # t1737 = prims.add(t1734, 0) # t1737: "cuda:0 i64[512, 1]" - # t1740 = ltorch.ge(t1737, t1736) # t1740: "cuda:0 b8[512, 512]" - # t1738 = prims.broadcast_in_dim(t1737, (512, 512), (0, 1)) # t1738: "cuda:0 i64[512, 512]" - # t1739 = prims.broadcast_in_dim(t1736, (512, 512), (0, 1)) # t1739: "cuda:0 i64[512, 512]" - # t1740 = prims.ge(t1738, t1739) # t1740: "cuda:0 b8[512, 512]" - # t1742 = ltorch.where(t1740, t1732, -float('inf')) # t1742: "cuda:0 bf16[1, 32, 512, 512]" - # t1741 = prims.broadcast_in_dim(t1740, (1, 32, 512, 512), (2, 3)) # t1741: "cuda:0 b8[1, 32, 512, 512]" - # t1742 = prims.where(t1741, t1732, -float('inf')) # t1742: "cuda:0 bf16[1, 32, 512, 512]" - # t1753 = ltorch._softmax(t1742, -1, dtype=None) # t1753: "cuda:0 bf16[1, 32, 512, 512]" - # t1743 = ltorch.to(t1742, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1743: "cuda:0 f32[1, 32, 512, 512]" - # t1743 = prims.convert_element_type(t1742, dtypes.float32) # t1743: "cuda:0 f32[1, 32, 512, 512]" - # t1745 = ltorch.amax(t1743, -1, True) # t1745: "cuda:0 f32[1, 32, 512, 1]" - # t1744 = prims.amax(t1743, (3,)) # t1744: "cuda:0 f32[1, 32, 512]" - # t1745 = prims.broadcast_in_dim(t1744, [1, 32, 512, 1], [0, 1, 2]) # t1745: "cuda:0 f32[1, 32, 512, 1]" - # t1747 = ltorch.sub(t1743, t1745, alpha=None) # t1747: "cuda:0 f32[1, 32, 512, 512]" - # t1746 = prims.broadcast_in_dim(t1745, (1, 32, 512, 512), (0, 1, 2, 3)) # t1746: "cuda:0 f32[1, 32, 512, 512]" - # t1747 = prims.sub(t1743, t1746) # t1747: "cuda:0 f32[1, 32, 512, 512]" - # t1748 = ltorch.exp(t1747) # t1748: "cuda:0 f32[1, 32, 512, 512]" - # t1748 = prims.exp(t1747) # t1748: "cuda:0 f32[1, 32, 512, 512]" - # t1750 = ltorch.sum(t1748, -1, True, dtype=None) # t1750: "cuda:0 f32[1, 32, 512, 1]" - # t1749 = prims.sum(t1748, (3,)) # t1749: "cuda:0 f32[1, 32, 512]" - # t1750 = prims.broadcast_in_dim(t1749, [1, 32, 512, 1], [0, 1, 2]) # t1750: "cuda:0 f32[1, 32, 512, 1]" - # t1752 = ltorch.true_divide(t1748, t1750) # t1752: "cuda:0 f32[1, 32, 512, 512]" - # t1751 = prims.broadcast_in_dim(t1750, (1, 32, 512, 512), (0, 1, 2, 3)) # t1751: "cuda:0 f32[1, 32, 512, 512]" - # t1752 = prims.div(t1748, t1751) # t1752: "cuda:0 f32[1, 32, 512, 512]" - # t1753 = ltorch.to(t1752, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1753: "cuda:0 bf16[1, 32, 512, 512]" - # t1753 = prims.convert_element_type(t1752, dtypes.bfloat16) # t1753: "cuda:0 bf16[1, 32, 512, 512]" - # t1754 = ltorch.matmul(t1753, t1690) # t1754: "cuda:0 bf16[1, 32, 512, 128]" - # t1754 = prims.matmul(t1753, t1690) # t1754: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1755 = ltorch.transpose(t1754, 1, 2) # t1755: "cuda:0 bf16[1, 512, 32, 128]" - # t1755 = prims.transpose(t1754, (0, 2, 1, 3)) # t1755: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1756 = ltorch.reshape(t1755, 1, 512, 4096) # t1756: "cuda:0 bf16[1, 512, 4096]" - # t1756 = prims.reshape(t1755, (1, 512, 4096)) # t1756: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1760 = ltorch.linear(t1756, t_transformer_h_10_attn_proj_weight, None) # t1760: "cuda:0 bf16[1, 512, 4096]" - # t1760 = prims.linear(t1756, t_transformer_h_10_attn_proj_weight, None) # t1760: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1764 = ltorch.add(t1760, t1654, alpha=None) # t1764: "cuda:0 bf16[1, 512, 4096]" - # t1761 = prims.convert_element_type(t1760, dtypes.float32) # t1761: "cuda:0 f32[1, 512, 4096]" - # t1762 = prims.convert_element_type(t1654, dtypes.float32) # t1762: "cuda:0 f32[1, 512, 4096]" - # t1763 = prims.add(t1761, t1762) # t1763: "cuda:0 f32[1, 512, 4096]" - # t1764 = prims.convert_element_type(t1763, dtypes.bfloat16) # t1764: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1765 = prims.convert_element_type(t1764, dtypes.float32) # t1765: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1766 = ltorch.mul(t1765, t1765) # t1766: "cuda:0 f32[1, 512, 4096]" - # t1766 = prims.mul(t1765, t1765) # t1766: "cuda:0 f32[1, 512, 4096]" - t1770 = ltorch.mean(t1766, -1, True, dtype=None) # t1770: "cuda:0 f32[1, 512, 1]" - # t1768 = prims.sum(t1766, (2,)) # t1768: "cuda:0 f32[1, 512]" - # t1769 = prims.broadcast_in_dim(t1768, [1, 512, 1], [0, 1]) # t1769: "cuda:0 f32[1, 512, 1]" - # t1770 = ltorch.true_divide(t1769, 4096) # t1770: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1770 = prims.div(t1769, 4096.0) # t1770: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1772 = ltorch.add(t1770, 1e-05, alpha=None) # t1772: "cuda:0 f32[1, 512, 1]" - # t1772 = prims.add(t1770, 1e-05) # t1772: "cuda:0 f32[1, 512, 1]" - t1773 = ltorch.rsqrt(t1772) # t1773: "cuda:0 f32[1, 512, 1]" - # t1773 = prims.rsqrt(t1772) # t1773: "cuda:0 f32[1, 512, 1]" - t1775 = ltorch.mul(t1765, t1773) # t1775: "cuda:0 f32[1, 512, 4096]" - # t1774 = prims.broadcast_in_dim(t1773, (1, 512, 4096), (0, 1, 2)) # t1774: "cuda:0 f32[1, 512, 4096]" - # t1775 = prims.mul(t1765, t1774) # t1775: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1776 = ltorch.to(t1775, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1776: "cuda:0 bf16[1, 512, 4096]" - # t1776 = prims.convert_element_type(t1775, dtypes.bfloat16) # t1776: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1786 = ltorch.mul(t1776, t_transformer_h_10_norm_2_weight) # t1786: "cuda:0 bf16[1, 512, 4096]" - # t1782 = prims.broadcast_in_dim(t_transformer_h_10_norm_2_weight, (1, 512, 4096), (2,)) # t1782: "cuda:0 bf16[1, 512, 4096]" - # t1783 = prims.convert_element_type(t1776, dtypes.float32) # t1783: "cuda:0 f32[1, 512, 4096]" - # t1784 = prims.convert_element_type(t1782, dtypes.float32) # t1784: "cuda:0 f32[1, 512, 4096]" - # t1785 = prims.mul(t1783, t1784) # t1785: "cuda:0 f32[1, 512, 4096]" - # t1786 = prims.convert_element_type(t1785, dtypes.bfloat16) # t1786: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1791 = ltorch.linear(t1786, t_transformer_h_10_mlp_fc_1_weight, None) # t1791: "cuda:0 bf16[1, 512, 11008]" - # t1791 = prims.linear(t1786, t_transformer_h_10_mlp_fc_1_weight, None) # t1791: "cuda:0 bf16[1, 512, 11008]" - t1795 = ltorch.linear(t1786, t_transformer_h_10_mlp_fc_2_weight, None) # t1795: "cuda:0 bf16[1, 512, 11008]" - # t1795 = prims.linear(t1786, t_transformer_h_10_mlp_fc_2_weight, None) # t1795: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1805 = ltorch.silu(t1791, False) # t1805: "cuda:0 bf16[1, 512, 11008]" - # t1796 = prims.convert_element_type(t1791, dtypes.float32) # t1796: "cuda:0 f32[1, 512, 11008]" - # t1797 = prims.neg(t1796) # t1797: "cuda:0 f32[1, 512, 11008]" - # t1798 = prims.exp(t1797) # t1798: "cuda:0 f32[1, 512, 11008]" - # t1799 = prims.add(1.0, t1798) # t1799: "cuda:0 f32[1, 512, 11008]" - # t1800 = prims.reciprocal(t1799) # t1800: "cuda:0 f32[1, 512, 11008]" - # t1801 = prims.convert_element_type(t1800, dtypes.bfloat16) # t1801: "cuda:0 bf16[1, 512, 11008]" - # t1802 = prims.convert_element_type(t1791, dtypes.float32) # t1802: "cuda:0 f32[1, 512, 11008]" - # t1803 = prims.convert_element_type(t1801, dtypes.float32) # t1803: "cuda:0 f32[1, 512, 11008]" - # t1804 = prims.mul(t1802, t1803) # t1804: "cuda:0 f32[1, 512, 11008]" - # t1805 = prims.convert_element_type(t1804, dtypes.bfloat16) # t1805: "cuda:0 bf16[1, 512, 11008]" - t1809 = ltorch.mul(t1805, t1795) # t1809: "cuda:0 bf16[1, 512, 11008]" - # t1806 = prims.convert_element_type(t1805, dtypes.float32) # t1806: "cuda:0 f32[1, 512, 11008]" - # t1807 = prims.convert_element_type(t1795, dtypes.float32) # t1807: "cuda:0 f32[1, 512, 11008]" - # t1808 = prims.mul(t1806, t1807) # t1808: "cuda:0 f32[1, 512, 11008]" - # t1809 = prims.convert_element_type(t1808, dtypes.bfloat16) # t1809: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1813 = ltorch.linear(t1809, t_transformer_h_10_mlp_proj_weight, None) # t1813: "cuda:0 bf16[1, 512, 4096]" - # t1813 = prims.linear(t1809, t_transformer_h_10_mlp_proj_weight, None) # t1813: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1817 = ltorch.add(t1813, t1764, alpha=None) # t1817: "cuda:0 bf16[1, 512, 4096]" - # t1814 = prims.convert_element_type(t1813, dtypes.float32) # t1814: "cuda:0 f32[1, 512, 4096]" - # t1815 = prims.convert_element_type(t1764, dtypes.float32) # t1815: "cuda:0 f32[1, 512, 4096]" - # t1816 = prims.add(t1814, t1815) # t1816: "cuda:0 f32[1, 512, 4096]" - # t1817 = prims.convert_element_type(t1816, dtypes.bfloat16) # t1817: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1819 = prims.convert_element_type(t1817, dtypes.float32) # t1819: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1820 = ltorch.mul(t1819, t1819) # t1820: "cuda:0 f32[1, 512, 4096]" - # t1820 = prims.mul(t1819, t1819) # t1820: "cuda:0 f32[1, 512, 4096]" - t1824 = ltorch.mean(t1820, -1, True, dtype=None) # t1824: "cuda:0 f32[1, 512, 1]" - # t1822 = prims.sum(t1820, (2,)) # t1822: "cuda:0 f32[1, 512]" - # t1823 = prims.broadcast_in_dim(t1822, [1, 512, 1], [0, 1]) # t1823: "cuda:0 f32[1, 512, 1]" - # t1824 = ltorch.true_divide(t1823, 4096) # t1824: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1824 = prims.div(t1823, 4096.0) # t1824: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1826 = ltorch.add(t1824, 1e-05, alpha=None) # t1826: "cuda:0 f32[1, 512, 1]" - # t1826 = prims.add(t1824, 1e-05) # t1826: "cuda:0 f32[1, 512, 1]" - t1827 = ltorch.rsqrt(t1826) # t1827: "cuda:0 f32[1, 512, 1]" - # t1827 = prims.rsqrt(t1826) # t1827: "cuda:0 f32[1, 512, 1]" - t1829 = ltorch.mul(t1819, t1827) # t1829: "cuda:0 f32[1, 512, 4096]" - # t1828 = prims.broadcast_in_dim(t1827, (1, 512, 4096), (0, 1, 2)) # t1828: "cuda:0 f32[1, 512, 4096]" - # t1829 = prims.mul(t1819, t1828) # t1829: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1830 = ltorch.to(t1829, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1830: "cuda:0 bf16[1, 512, 4096]" - # t1830 = prims.convert_element_type(t1829, dtypes.bfloat16) # t1830: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1840 = ltorch.mul(t1830, t_transformer_h_11_norm_1_weight) # t1840: "cuda:0 bf16[1, 512, 4096]" - # t1836 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_weight, (1, 512, 4096), (2,)) # t1836: "cuda:0 bf16[1, 512, 4096]" - # t1837 = prims.convert_element_type(t1830, dtypes.float32) # t1837: "cuda:0 f32[1, 512, 4096]" - # t1838 = prims.convert_element_type(t1836, dtypes.float32) # t1838: "cuda:0 f32[1, 512, 4096]" - # t1839 = prims.mul(t1837, t1838) # t1839: "cuda:0 f32[1, 512, 4096]" - # t1840 = prims.convert_element_type(t1839, dtypes.bfloat16) # t1840: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1845 = ltorch.linear(t1840, t_transformer_h_11_attn_attn_weight, None) # t1845: "cuda:0 bf16[1, 512, 12288]" - # t1845 = prims.linear(t1840, t_transformer_h_11_attn_attn_weight, None) # t1845: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1846 = ltorch.view(t1845, 1, 512, 32, 3, 128) # t1846: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1846 = ltorch.reshape(t1845, (1, 512, 32, 3, 128)) # t1846: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1846 = prims.reshape(t1845, (1, 512, 32, 3, 128)) # t1846: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1847 = ltorch.permute(t1846, 0, 2, 3, 1, 4) # t1847: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1847 = prims.transpose(t1846, (0, 2, 3, 1, 4)) # t1847: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1848, t1849, t1850) = ltorch.split(t1847, (1, 1, 1), 2) - # t1848 = prims.slice_prim(t1847, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1848: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1849 = prims.slice_prim(t1847, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1849: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1850 = prims.slice_prim(t1847, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1850: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1851 = ltorch.reshape(t1848, 1, -1, 512, 128) # t1851: "cuda:0 bf16[1, 32, 512, 128]" - # t1851 = prims.reshape(t1848, (1, 32, 512, 128)) # t1851: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1852 = ltorch.reshape(t1849, 1, -1, 512, 128) # t1852: "cuda:0 bf16[1, 32, 512, 128]" - # t1852 = prims.reshape(t1849, (1, 32, 512, 128)) # t1852: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1853 = ltorch.reshape(t1850, 1, -1, 512, 128) # t1853: "cuda:0 bf16[1, 32, 512, 128]" - # t1853 = prims.reshape(t1850, (1, 32, 512, 128)) # t1853: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1854 = ltorch.getitem(t1851, (..., slice(None, 128, None))) # t1854: "cuda:0 bf16[1, 32, 512, 128]" - # t1854 = prims.slice_prim(t1851, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1854: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1855 = ltorch.getitem(t1854, (..., slice(None, 64, None))) # t1855: "cuda:0 bf16[1, 32, 512, 64]" - # t1855 = prims.slice_prim(t1854, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1855: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1856 = ltorch.getitem(t1854, (..., slice(64, None, None))) # t1856: "cuda:0 bf16[1, 32, 512, 64]" - # t1856 = prims.slice_prim(t1854, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1856: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1859 = ltorch.neg(t1856) # t1859: "cuda:0 bf16[1, 32, 512, 64]" - # t1857 = prims.convert_element_type(t1856, dtypes.float32) # t1857: "cuda:0 f32[1, 32, 512, 64]" - # t1858 = prims.neg(t1857) # t1858: "cuda:0 f32[1, 32, 512, 64]" - # t1859 = prims.convert_element_type(t1858, dtypes.bfloat16) # t1859: "cuda:0 bf16[1, 32, 512, 64]" - t1860 = ltorch.cat((t1859, t1855), -1) # t1860: "cuda:0 bf16[1, 32, 512, 128]" - # t1860 = prims.cat((t1859, t1855), -1) # t1860: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1863 = ltorch.mul(t1854, cos) # t1863: "cuda:0 f32[1, 32, 512, 128]" - # t1861 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1861: "cuda:0 f32[1, 32, 512, 128]" - # t1862 = prims.convert_element_type(t1854, dtypes.float32) # t1862: "cuda:0 f32[1, 32, 512, 128]" - # t1863 = prims.mul(t1862, t1861) # t1863: "cuda:0 f32[1, 32, 512, 128]" - t1866 = ltorch.mul(t1860, sin) # t1866: "cuda:0 f32[1, 32, 512, 128]" - # t1864 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1864: "cuda:0 f32[1, 32, 512, 128]" - # t1865 = prims.convert_element_type(t1860, dtypes.float32) # t1865: "cuda:0 f32[1, 32, 512, 128]" - # t1866 = prims.mul(t1865, t1864) # t1866: "cuda:0 f32[1, 32, 512, 128]" - t1867 = ltorch.add(t1863, t1866, alpha=None) # t1867: "cuda:0 f32[1, 32, 512, 128]" - # t1867 = prims.add(t1863, t1866) # t1867: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1868 = ltorch.to(t1867, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1868: "cuda:0 bf16[1, 32, 512, 128]" - # t1868 = prims.convert_element_type(t1867, dtypes.bfloat16) # t1868: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1869 = ltorch.getitem(t1852, (..., slice(None, 128, None))) # t1869: "cuda:0 bf16[1, 32, 512, 128]" - # t1869 = prims.slice_prim(t1852, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1869: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1870 = ltorch.getitem(t1869, (..., slice(None, 64, None))) # t1870: "cuda:0 bf16[1, 32, 512, 64]" - # t1870 = prims.slice_prim(t1869, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1870: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1871 = ltorch.getitem(t1869, (..., slice(64, None, None))) # t1871: "cuda:0 bf16[1, 32, 512, 64]" - # t1871 = prims.slice_prim(t1869, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1871: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1874 = ltorch.neg(t1871) # t1874: "cuda:0 bf16[1, 32, 512, 64]" - # t1872 = prims.convert_element_type(t1871, dtypes.float32) # t1872: "cuda:0 f32[1, 32, 512, 64]" - # t1873 = prims.neg(t1872) # t1873: "cuda:0 f32[1, 32, 512, 64]" - # t1874 = prims.convert_element_type(t1873, dtypes.bfloat16) # t1874: "cuda:0 bf16[1, 32, 512, 64]" - t1875 = ltorch.cat((t1874, t1870), -1) # t1875: "cuda:0 bf16[1, 32, 512, 128]" - # t1875 = prims.cat((t1874, t1870), -1) # t1875: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1878 = ltorch.mul(t1869, cos) # t1878: "cuda:0 f32[1, 32, 512, 128]" - # t1876 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1876: "cuda:0 f32[1, 32, 512, 128]" - # t1877 = prims.convert_element_type(t1869, dtypes.float32) # t1877: "cuda:0 f32[1, 32, 512, 128]" - # t1878 = prims.mul(t1877, t1876) # t1878: "cuda:0 f32[1, 32, 512, 128]" - t1881 = ltorch.mul(t1875, sin) # t1881: "cuda:0 f32[1, 32, 512, 128]" - # t1879 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1879: "cuda:0 f32[1, 32, 512, 128]" - # t1880 = prims.convert_element_type(t1875, dtypes.float32) # t1880: "cuda:0 f32[1, 32, 512, 128]" - # t1881 = prims.mul(t1880, t1879) # t1881: "cuda:0 f32[1, 32, 512, 128]" - t1882 = ltorch.add(t1878, t1881, alpha=None) # t1882: "cuda:0 f32[1, 32, 512, 128]" - # t1882 = prims.add(t1878, t1881) # t1882: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1883 = ltorch.to(t1882, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1883: "cuda:0 bf16[1, 32, 512, 128]" - # t1883 = prims.convert_element_type(t1882, dtypes.bfloat16) # t1883: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1884 = ltorch.getitem(t1851, (..., slice(128, None, None))) # t1884: "cuda:0 bf16[1, 32, 512, 0]" - # t1884 = prims.slice_prim(t1851, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1884: "cuda:0 bf16[1, 32, 512, 0]" - t1885 = ltorch.cat((t1868, t1884), -1) # t1885: "cuda:0 bf16[1, 32, 512, 128]" - # t1885 = prims.cat((t1868, t1884), -1) # t1885: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1886 = ltorch.getitem(t1852, (..., slice(128, None, None))) # t1886: "cuda:0 bf16[1, 32, 512, 0]" - # t1886 = prims.slice_prim(t1852, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1886: "cuda:0 bf16[1, 32, 512, 0]" - t1887 = ltorch.cat((t1883, t1886), -1) # t1887: "cuda:0 bf16[1, 32, 512, 128]" - # t1887 = prims.cat((t1883, t1886), -1) # t1887: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1917 = ltorch.scaled_dot_product_attention(t1885, t1887, t1853, None, 0.0, True, scale=0.08838834764831843) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - # t1890 = ltorch.mul(t1885, 0.29730177875068026) # t1890: "cuda:0 bf16[1, 32, 512, 128]" - # t1888 = prims.convert_element_type(t1885, dtypes.float32) # t1888: "cuda:0 f32[1, 32, 512, 128]" - # t1889 = prims.mul(t1888, 0.29730177875068026) # t1889: "cuda:0 f32[1, 32, 512, 128]" - # t1890 = prims.convert_element_type(t1889, dtypes.bfloat16) # t1890: "cuda:0 bf16[1, 32, 512, 128]" - # t1891 = ltorch.transpose(t1887, -2, -1) # t1891: "cuda:0 bf16[1, 32, 128, 512]" - # t1891 = prims.transpose(t1887, (0, 1, 3, 2)) # t1891: "cuda:0 bf16[1, 32, 128, 512]" - # t1894 = ltorch.mul(t1891, 0.29730177875068026) # t1894: "cuda:0 bf16[1, 32, 128, 512]" - # t1892 = prims.convert_element_type(t1891, dtypes.float32) # t1892: "cuda:0 f32[1, 32, 128, 512]" - # t1893 = prims.mul(t1892, 0.29730177875068026) # t1893: "cuda:0 f32[1, 32, 128, 512]" - # t1894 = prims.convert_element_type(t1893, dtypes.bfloat16) # t1894: "cuda:0 bf16[1, 32, 128, 512]" - # t1895 = ltorch.matmul(t1890, t1894) # t1895: "cuda:0 bf16[1, 32, 512, 512]" - # t1895 = prims.matmul(t1890, t1894) # t1895: "cuda:0 bf16[1, 32, 512, 512]" - # t1905 = ltorch.tril(t1895, 0, fill_value=-float('inf')) # t1905: "cuda:0 bf16[1, 32, 512, 512]" - # t1896 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1896: "cuda:0 i64[512]" - # t1896 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1896: "cuda:0 i64[512]" - # t1897 = ltorch.unsqueeze(t1896, -1) # t1897: "cuda:0 i64[512, 1]" - # t1897 = prims.broadcast_in_dim(t1896, [512, 1], [0]) # t1897: "cuda:0 i64[512, 1]" - # t1898 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1898: "cuda:0 i64[512]" - # t1898 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1898: "cuda:0 i64[512]" - # t1899 = ltorch.unsqueeze(t1898, -2) # t1899: "cuda:0 i64[1, 512]" - # t1899 = prims.broadcast_in_dim(t1898, [1, 512], [1]) # t1899: "cuda:0 i64[1, 512]" - # t1900 = ltorch.add(t1897, 0, alpha=None) # t1900: "cuda:0 i64[512, 1]" - # t1900 = prims.add(t1897, 0) # t1900: "cuda:0 i64[512, 1]" - # t1903 = ltorch.ge(t1900, t1899) # t1903: "cuda:0 b8[512, 512]" - # t1901 = prims.broadcast_in_dim(t1900, (512, 512), (0, 1)) # t1901: "cuda:0 i64[512, 512]" - # t1902 = prims.broadcast_in_dim(t1899, (512, 512), (0, 1)) # t1902: "cuda:0 i64[512, 512]" - # t1903 = prims.ge(t1901, t1902) # t1903: "cuda:0 b8[512, 512]" - # t1905 = ltorch.where(t1903, t1895, -float('inf')) # t1905: "cuda:0 bf16[1, 32, 512, 512]" - # t1904 = prims.broadcast_in_dim(t1903, (1, 32, 512, 512), (2, 3)) # t1904: "cuda:0 b8[1, 32, 512, 512]" - # t1905 = prims.where(t1904, t1895, -float('inf')) # t1905: "cuda:0 bf16[1, 32, 512, 512]" - # t1916 = ltorch._softmax(t1905, -1, dtype=None) # t1916: "cuda:0 bf16[1, 32, 512, 512]" - # t1906 = ltorch.to(t1905, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1906: "cuda:0 f32[1, 32, 512, 512]" - # t1906 = prims.convert_element_type(t1905, dtypes.float32) # t1906: "cuda:0 f32[1, 32, 512, 512]" - # t1908 = ltorch.amax(t1906, -1, True) # t1908: "cuda:0 f32[1, 32, 512, 1]" - # t1907 = prims.amax(t1906, (3,)) # t1907: "cuda:0 f32[1, 32, 512]" - # t1908 = prims.broadcast_in_dim(t1907, [1, 32, 512, 1], [0, 1, 2]) # t1908: "cuda:0 f32[1, 32, 512, 1]" - # t1910 = ltorch.sub(t1906, t1908, alpha=None) # t1910: "cuda:0 f32[1, 32, 512, 512]" - # t1909 = prims.broadcast_in_dim(t1908, (1, 32, 512, 512), (0, 1, 2, 3)) # t1909: "cuda:0 f32[1, 32, 512, 512]" - # t1910 = prims.sub(t1906, t1909) # t1910: "cuda:0 f32[1, 32, 512, 512]" - # t1911 = ltorch.exp(t1910) # t1911: "cuda:0 f32[1, 32, 512, 512]" - # t1911 = prims.exp(t1910) # t1911: "cuda:0 f32[1, 32, 512, 512]" - # t1913 = ltorch.sum(t1911, -1, True, dtype=None) # t1913: "cuda:0 f32[1, 32, 512, 1]" - # t1912 = prims.sum(t1911, (3,)) # t1912: "cuda:0 f32[1, 32, 512]" - # t1913 = prims.broadcast_in_dim(t1912, [1, 32, 512, 1], [0, 1, 2]) # t1913: "cuda:0 f32[1, 32, 512, 1]" - # t1915 = ltorch.true_divide(t1911, t1913) # t1915: "cuda:0 f32[1, 32, 512, 512]" - # t1914 = prims.broadcast_in_dim(t1913, (1, 32, 512, 512), (0, 1, 2, 3)) # t1914: "cuda:0 f32[1, 32, 512, 512]" - # t1915 = prims.div(t1911, t1914) # t1915: "cuda:0 f32[1, 32, 512, 512]" - # t1916 = ltorch.to(t1915, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1916: "cuda:0 bf16[1, 32, 512, 512]" - # t1916 = prims.convert_element_type(t1915, dtypes.bfloat16) # t1916: "cuda:0 bf16[1, 32, 512, 512]" - # t1917 = ltorch.matmul(t1916, t1853) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - # t1917 = prims.matmul(t1916, t1853) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1918 = ltorch.transpose(t1917, 1, 2) # t1918: "cuda:0 bf16[1, 512, 32, 128]" - # t1918 = prims.transpose(t1917, (0, 2, 1, 3)) # t1918: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1919 = ltorch.reshape(t1918, 1, 512, 4096) # t1919: "cuda:0 bf16[1, 512, 4096]" - # t1919 = prims.reshape(t1918, (1, 512, 4096)) # t1919: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1923 = ltorch.linear(t1919, t_transformer_h_11_attn_proj_weight, None) # t1923: "cuda:0 bf16[1, 512, 4096]" - # t1923 = prims.linear(t1919, t_transformer_h_11_attn_proj_weight, None) # t1923: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1927 = ltorch.add(t1923, t1817, alpha=None) # t1927: "cuda:0 bf16[1, 512, 4096]" - # t1924 = prims.convert_element_type(t1923, dtypes.float32) # t1924: "cuda:0 f32[1, 512, 4096]" - # t1925 = prims.convert_element_type(t1817, dtypes.float32) # t1925: "cuda:0 f32[1, 512, 4096]" - # t1926 = prims.add(t1924, t1925) # t1926: "cuda:0 f32[1, 512, 4096]" - # t1927 = prims.convert_element_type(t1926, dtypes.bfloat16) # t1927: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1928 = prims.convert_element_type(t1927, dtypes.float32) # t1928: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1929 = ltorch.mul(t1928, t1928) # t1929: "cuda:0 f32[1, 512, 4096]" - # t1929 = prims.mul(t1928, t1928) # t1929: "cuda:0 f32[1, 512, 4096]" - t1933 = ltorch.mean(t1929, -1, True, dtype=None) # t1933: "cuda:0 f32[1, 512, 1]" - # t1931 = prims.sum(t1929, (2,)) # t1931: "cuda:0 f32[1, 512]" - # t1932 = prims.broadcast_in_dim(t1931, [1, 512, 1], [0, 1]) # t1932: "cuda:0 f32[1, 512, 1]" - # t1933 = ltorch.true_divide(t1932, 4096) # t1933: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1933 = prims.div(t1932, 4096.0) # t1933: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1935 = ltorch.add(t1933, 1e-05, alpha=None) # t1935: "cuda:0 f32[1, 512, 1]" - # t1935 = prims.add(t1933, 1e-05) # t1935: "cuda:0 f32[1, 512, 1]" - t1936 = ltorch.rsqrt(t1935) # t1936: "cuda:0 f32[1, 512, 1]" - # t1936 = prims.rsqrt(t1935) # t1936: "cuda:0 f32[1, 512, 1]" - t1938 = ltorch.mul(t1928, t1936) # t1938: "cuda:0 f32[1, 512, 4096]" - # t1937 = prims.broadcast_in_dim(t1936, (1, 512, 4096), (0, 1, 2)) # t1937: "cuda:0 f32[1, 512, 4096]" - # t1938 = prims.mul(t1928, t1937) # t1938: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1939 = ltorch.to(t1938, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1939: "cuda:0 bf16[1, 512, 4096]" - # t1939 = prims.convert_element_type(t1938, dtypes.bfloat16) # t1939: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1949 = ltorch.mul(t1939, t_transformer_h_11_norm_2_weight) # t1949: "cuda:0 bf16[1, 512, 4096]" - # t1945 = prims.broadcast_in_dim(t_transformer_h_11_norm_2_weight, (1, 512, 4096), (2,)) # t1945: "cuda:0 bf16[1, 512, 4096]" - # t1946 = prims.convert_element_type(t1939, dtypes.float32) # t1946: "cuda:0 f32[1, 512, 4096]" - # t1947 = prims.convert_element_type(t1945, dtypes.float32) # t1947: "cuda:0 f32[1, 512, 4096]" - # t1948 = prims.mul(t1946, t1947) # t1948: "cuda:0 f32[1, 512, 4096]" - # t1949 = prims.convert_element_type(t1948, dtypes.bfloat16) # t1949: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1954 = ltorch.linear(t1949, t_transformer_h_11_mlp_fc_1_weight, None) # t1954: "cuda:0 bf16[1, 512, 11008]" - # t1954 = prims.linear(t1949, t_transformer_h_11_mlp_fc_1_weight, None) # t1954: "cuda:0 bf16[1, 512, 11008]" - t1958 = ltorch.linear(t1949, t_transformer_h_11_mlp_fc_2_weight, None) # t1958: "cuda:0 bf16[1, 512, 11008]" - # t1958 = prims.linear(t1949, t_transformer_h_11_mlp_fc_2_weight, None) # t1958: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1968 = ltorch.silu(t1954, False) # t1968: "cuda:0 bf16[1, 512, 11008]" - # t1959 = prims.convert_element_type(t1954, dtypes.float32) # t1959: "cuda:0 f32[1, 512, 11008]" - # t1960 = prims.neg(t1959) # t1960: "cuda:0 f32[1, 512, 11008]" - # t1961 = prims.exp(t1960) # t1961: "cuda:0 f32[1, 512, 11008]" - # t1962 = prims.add(1.0, t1961) # t1962: "cuda:0 f32[1, 512, 11008]" - # t1963 = prims.reciprocal(t1962) # t1963: "cuda:0 f32[1, 512, 11008]" - # t1964 = prims.convert_element_type(t1963, dtypes.bfloat16) # t1964: "cuda:0 bf16[1, 512, 11008]" - # t1965 = prims.convert_element_type(t1954, dtypes.float32) # t1965: "cuda:0 f32[1, 512, 11008]" - # t1966 = prims.convert_element_type(t1964, dtypes.float32) # t1966: "cuda:0 f32[1, 512, 11008]" - # t1967 = prims.mul(t1965, t1966) # t1967: "cuda:0 f32[1, 512, 11008]" - # t1968 = prims.convert_element_type(t1967, dtypes.bfloat16) # t1968: "cuda:0 bf16[1, 512, 11008]" - t1972 = ltorch.mul(t1968, t1958) # t1972: "cuda:0 bf16[1, 512, 11008]" - # t1969 = prims.convert_element_type(t1968, dtypes.float32) # t1969: "cuda:0 f32[1, 512, 11008]" - # t1970 = prims.convert_element_type(t1958, dtypes.float32) # t1970: "cuda:0 f32[1, 512, 11008]" - # t1971 = prims.mul(t1969, t1970) # t1971: "cuda:0 f32[1, 512, 11008]" - # t1972 = prims.convert_element_type(t1971, dtypes.bfloat16) # t1972: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1976 = ltorch.linear(t1972, t_transformer_h_11_mlp_proj_weight, None) # t1976: "cuda:0 bf16[1, 512, 4096]" - # t1976 = prims.linear(t1972, t_transformer_h_11_mlp_proj_weight, None) # t1976: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1980 = ltorch.add(t1976, t1927, alpha=None) # t1980: "cuda:0 bf16[1, 512, 4096]" - # t1977 = prims.convert_element_type(t1976, dtypes.float32) # t1977: "cuda:0 f32[1, 512, 4096]" - # t1978 = prims.convert_element_type(t1927, dtypes.float32) # t1978: "cuda:0 f32[1, 512, 4096]" - # t1979 = prims.add(t1977, t1978) # t1979: "cuda:0 f32[1, 512, 4096]" - # t1980 = prims.convert_element_type(t1979, dtypes.bfloat16) # t1980: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1982 = prims.convert_element_type(t1980, dtypes.float32) # t1982: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1983 = ltorch.mul(t1982, t1982) # t1983: "cuda:0 f32[1, 512, 4096]" - # t1983 = prims.mul(t1982, t1982) # t1983: "cuda:0 f32[1, 512, 4096]" - t1987 = ltorch.mean(t1983, -1, True, dtype=None) # t1987: "cuda:0 f32[1, 512, 1]" - # t1985 = prims.sum(t1983, (2,)) # t1985: "cuda:0 f32[1, 512]" - # t1986 = prims.broadcast_in_dim(t1985, [1, 512, 1], [0, 1]) # t1986: "cuda:0 f32[1, 512, 1]" - # t1987 = ltorch.true_divide(t1986, 4096) # t1987: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1987 = prims.div(t1986, 4096.0) # t1987: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1989 = ltorch.add(t1987, 1e-05, alpha=None) # t1989: "cuda:0 f32[1, 512, 1]" - # t1989 = prims.add(t1987, 1e-05) # t1989: "cuda:0 f32[1, 512, 1]" - t1990 = ltorch.rsqrt(t1989) # t1990: "cuda:0 f32[1, 512, 1]" - # t1990 = prims.rsqrt(t1989) # t1990: "cuda:0 f32[1, 512, 1]" - t1992 = ltorch.mul(t1982, t1990) # t1992: "cuda:0 f32[1, 512, 4096]" - # t1991 = prims.broadcast_in_dim(t1990, (1, 512, 4096), (0, 1, 2)) # t1991: "cuda:0 f32[1, 512, 4096]" - # t1992 = prims.mul(t1982, t1991) # t1992: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1993 = ltorch.to(t1992, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1993: "cuda:0 bf16[1, 512, 4096]" - # t1993 = prims.convert_element_type(t1992, dtypes.bfloat16) # t1993: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2003 = ltorch.mul(t1993, t_transformer_h_12_norm_1_weight) # t2003: "cuda:0 bf16[1, 512, 4096]" - # t1999 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_weight, (1, 512, 4096), (2,)) # t1999: "cuda:0 bf16[1, 512, 4096]" - # t2000 = prims.convert_element_type(t1993, dtypes.float32) # t2000: "cuda:0 f32[1, 512, 4096]" - # t2001 = prims.convert_element_type(t1999, dtypes.float32) # t2001: "cuda:0 f32[1, 512, 4096]" - # t2002 = prims.mul(t2000, t2001) # t2002: "cuda:0 f32[1, 512, 4096]" - # t2003 = prims.convert_element_type(t2002, dtypes.bfloat16) # t2003: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2008 = ltorch.linear(t2003, t_transformer_h_12_attn_attn_weight, None) # t2008: "cuda:0 bf16[1, 512, 12288]" - # t2008 = prims.linear(t2003, t_transformer_h_12_attn_attn_weight, None) # t2008: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2009 = ltorch.view(t2008, 1, 512, 32, 3, 128) # t2009: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2009 = ltorch.reshape(t2008, (1, 512, 32, 3, 128)) # t2009: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2009 = prims.reshape(t2008, (1, 512, 32, 3, 128)) # t2009: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2010 = ltorch.permute(t2009, 0, 2, 3, 1, 4) # t2010: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2010 = prims.transpose(t2009, (0, 2, 3, 1, 4)) # t2010: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2011, t2012, t2013) = ltorch.split(t2010, (1, 1, 1), 2) - # t2011 = prims.slice_prim(t2010, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2011: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2012 = prims.slice_prim(t2010, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2012: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2013 = prims.slice_prim(t2010, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2013: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2014 = ltorch.reshape(t2011, 1, -1, 512, 128) # t2014: "cuda:0 bf16[1, 32, 512, 128]" - # t2014 = prims.reshape(t2011, (1, 32, 512, 128)) # t2014: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2015 = ltorch.reshape(t2012, 1, -1, 512, 128) # t2015: "cuda:0 bf16[1, 32, 512, 128]" - # t2015 = prims.reshape(t2012, (1, 32, 512, 128)) # t2015: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2016 = ltorch.reshape(t2013, 1, -1, 512, 128) # t2016: "cuda:0 bf16[1, 32, 512, 128]" - # t2016 = prims.reshape(t2013, (1, 32, 512, 128)) # t2016: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2017 = ltorch.getitem(t2014, (..., slice(None, 128, None))) # t2017: "cuda:0 bf16[1, 32, 512, 128]" - # t2017 = prims.slice_prim(t2014, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2017: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2018 = ltorch.getitem(t2017, (..., slice(None, 64, None))) # t2018: "cuda:0 bf16[1, 32, 512, 64]" - # t2018 = prims.slice_prim(t2017, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2018: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2019 = ltorch.getitem(t2017, (..., slice(64, None, None))) # t2019: "cuda:0 bf16[1, 32, 512, 64]" - # t2019 = prims.slice_prim(t2017, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2019: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2022 = ltorch.neg(t2019) # t2022: "cuda:0 bf16[1, 32, 512, 64]" - # t2020 = prims.convert_element_type(t2019, dtypes.float32) # t2020: "cuda:0 f32[1, 32, 512, 64]" - # t2021 = prims.neg(t2020) # t2021: "cuda:0 f32[1, 32, 512, 64]" - # t2022 = prims.convert_element_type(t2021, dtypes.bfloat16) # t2022: "cuda:0 bf16[1, 32, 512, 64]" - t2023 = ltorch.cat((t2022, t2018), -1) # t2023: "cuda:0 bf16[1, 32, 512, 128]" - # t2023 = prims.cat((t2022, t2018), -1) # t2023: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2026 = ltorch.mul(t2017, cos) # t2026: "cuda:0 f32[1, 32, 512, 128]" - # t2024 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2024: "cuda:0 f32[1, 32, 512, 128]" - # t2025 = prims.convert_element_type(t2017, dtypes.float32) # t2025: "cuda:0 f32[1, 32, 512, 128]" - # t2026 = prims.mul(t2025, t2024) # t2026: "cuda:0 f32[1, 32, 512, 128]" - t2029 = ltorch.mul(t2023, sin) # t2029: "cuda:0 f32[1, 32, 512, 128]" - # t2027 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2027: "cuda:0 f32[1, 32, 512, 128]" - # t2028 = prims.convert_element_type(t2023, dtypes.float32) # t2028: "cuda:0 f32[1, 32, 512, 128]" - # t2029 = prims.mul(t2028, t2027) # t2029: "cuda:0 f32[1, 32, 512, 128]" - t2030 = ltorch.add(t2026, t2029, alpha=None) # t2030: "cuda:0 f32[1, 32, 512, 128]" - # t2030 = prims.add(t2026, t2029) # t2030: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2031 = ltorch.to(t2030, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2031: "cuda:0 bf16[1, 32, 512, 128]" - # t2031 = prims.convert_element_type(t2030, dtypes.bfloat16) # t2031: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2032 = ltorch.getitem(t2015, (..., slice(None, 128, None))) # t2032: "cuda:0 bf16[1, 32, 512, 128]" - # t2032 = prims.slice_prim(t2015, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2032: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2033 = ltorch.getitem(t2032, (..., slice(None, 64, None))) # t2033: "cuda:0 bf16[1, 32, 512, 64]" - # t2033 = prims.slice_prim(t2032, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2033: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2034 = ltorch.getitem(t2032, (..., slice(64, None, None))) # t2034: "cuda:0 bf16[1, 32, 512, 64]" - # t2034 = prims.slice_prim(t2032, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2034: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2037 = ltorch.neg(t2034) # t2037: "cuda:0 bf16[1, 32, 512, 64]" - # t2035 = prims.convert_element_type(t2034, dtypes.float32) # t2035: "cuda:0 f32[1, 32, 512, 64]" - # t2036 = prims.neg(t2035) # t2036: "cuda:0 f32[1, 32, 512, 64]" - # t2037 = prims.convert_element_type(t2036, dtypes.bfloat16) # t2037: "cuda:0 bf16[1, 32, 512, 64]" - t2038 = ltorch.cat((t2037, t2033), -1) # t2038: "cuda:0 bf16[1, 32, 512, 128]" - # t2038 = prims.cat((t2037, t2033), -1) # t2038: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2041 = ltorch.mul(t2032, cos) # t2041: "cuda:0 f32[1, 32, 512, 128]" - # t2039 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2039: "cuda:0 f32[1, 32, 512, 128]" - # t2040 = prims.convert_element_type(t2032, dtypes.float32) # t2040: "cuda:0 f32[1, 32, 512, 128]" - # t2041 = prims.mul(t2040, t2039) # t2041: "cuda:0 f32[1, 32, 512, 128]" - t2044 = ltorch.mul(t2038, sin) # t2044: "cuda:0 f32[1, 32, 512, 128]" - # t2042 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2042: "cuda:0 f32[1, 32, 512, 128]" - # t2043 = prims.convert_element_type(t2038, dtypes.float32) # t2043: "cuda:0 f32[1, 32, 512, 128]" - # t2044 = prims.mul(t2043, t2042) # t2044: "cuda:0 f32[1, 32, 512, 128]" - t2045 = ltorch.add(t2041, t2044, alpha=None) # t2045: "cuda:0 f32[1, 32, 512, 128]" - # t2045 = prims.add(t2041, t2044) # t2045: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2046 = ltorch.to(t2045, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2046: "cuda:0 bf16[1, 32, 512, 128]" - # t2046 = prims.convert_element_type(t2045, dtypes.bfloat16) # t2046: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2047 = ltorch.getitem(t2014, (..., slice(128, None, None))) # t2047: "cuda:0 bf16[1, 32, 512, 0]" - # t2047 = prims.slice_prim(t2014, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2047: "cuda:0 bf16[1, 32, 512, 0]" - t2048 = ltorch.cat((t2031, t2047), -1) # t2048: "cuda:0 bf16[1, 32, 512, 128]" - # t2048 = prims.cat((t2031, t2047), -1) # t2048: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2049 = ltorch.getitem(t2015, (..., slice(128, None, None))) # t2049: "cuda:0 bf16[1, 32, 512, 0]" - # t2049 = prims.slice_prim(t2015, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2049: "cuda:0 bf16[1, 32, 512, 0]" - t2050 = ltorch.cat((t2046, t2049), -1) # t2050: "cuda:0 bf16[1, 32, 512, 128]" - # t2050 = prims.cat((t2046, t2049), -1) # t2050: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2080 = ltorch.scaled_dot_product_attention(t2048, t2050, t2016, None, 0.0, True, scale=0.08838834764831843) # t2080: "cuda:0 bf16[1, 32, 512, 128]" - # t2053 = ltorch.mul(t2048, 0.29730177875068026) # t2053: "cuda:0 bf16[1, 32, 512, 128]" - # t2051 = prims.convert_element_type(t2048, dtypes.float32) # t2051: "cuda:0 f32[1, 32, 512, 128]" - # t2052 = prims.mul(t2051, 0.29730177875068026) # t2052: "cuda:0 f32[1, 32, 512, 128]" - # t2053 = prims.convert_element_type(t2052, dtypes.bfloat16) # t2053: "cuda:0 bf16[1, 32, 512, 128]" - # t2054 = ltorch.transpose(t2050, -2, -1) # t2054: "cuda:0 bf16[1, 32, 128, 512]" - # t2054 = prims.transpose(t2050, (0, 1, 3, 2)) # t2054: "cuda:0 bf16[1, 32, 128, 512]" - # t2057 = ltorch.mul(t2054, 0.29730177875068026) # t2057: "cuda:0 bf16[1, 32, 128, 512]" - # t2055 = prims.convert_element_type(t2054, dtypes.float32) # t2055: "cuda:0 f32[1, 32, 128, 512]" - # t2056 = prims.mul(t2055, 0.29730177875068026) # t2056: "cuda:0 f32[1, 32, 128, 512]" - # t2057 = prims.convert_element_type(t2056, dtypes.bfloat16) # t2057: "cuda:0 bf16[1, 32, 128, 512]" - # t2058 = ltorch.matmul(t2053, t2057) # t2058: "cuda:0 bf16[1, 32, 512, 512]" - # t2058 = prims.matmul(t2053, t2057) # t2058: "cuda:0 bf16[1, 32, 512, 512]" - # t2068 = ltorch.tril(t2058, 0, fill_value=-float('inf')) # t2068: "cuda:0 bf16[1, 32, 512, 512]" - # t2059 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2059: "cuda:0 i64[512]" - # t2059 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2059: "cuda:0 i64[512]" - # t2060 = ltorch.unsqueeze(t2059, -1) # t2060: "cuda:0 i64[512, 1]" - # t2060 = prims.broadcast_in_dim(t2059, [512, 1], [0]) # t2060: "cuda:0 i64[512, 1]" - # t2061 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2061: "cuda:0 i64[512]" - # t2061 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2061: "cuda:0 i64[512]" - # t2062 = ltorch.unsqueeze(t2061, -2) # t2062: "cuda:0 i64[1, 512]" - # t2062 = prims.broadcast_in_dim(t2061, [1, 512], [1]) # t2062: "cuda:0 i64[1, 512]" - # t2063 = ltorch.add(t2060, 0, alpha=None) # t2063: "cuda:0 i64[512, 1]" - # t2063 = prims.add(t2060, 0) # t2063: "cuda:0 i64[512, 1]" - # t2066 = ltorch.ge(t2063, t2062) # t2066: "cuda:0 b8[512, 512]" - # t2064 = prims.broadcast_in_dim(t2063, (512, 512), (0, 1)) # t2064: "cuda:0 i64[512, 512]" - # t2065 = prims.broadcast_in_dim(t2062, (512, 512), (0, 1)) # t2065: "cuda:0 i64[512, 512]" - # t2066 = prims.ge(t2064, t2065) # t2066: "cuda:0 b8[512, 512]" - # t2068 = ltorch.where(t2066, t2058, -float('inf')) # t2068: "cuda:0 bf16[1, 32, 512, 512]" - # t2067 = prims.broadcast_in_dim(t2066, (1, 32, 512, 512), (2, 3)) # t2067: "cuda:0 b8[1, 32, 512, 512]" - # t2068 = prims.where(t2067, t2058, -float('inf')) # t2068: "cuda:0 bf16[1, 32, 512, 512]" - # t2079 = ltorch._softmax(t2068, -1, dtype=None) # t2079: "cuda:0 bf16[1, 32, 512, 512]" - # t2069 = ltorch.to(t2068, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2069: "cuda:0 f32[1, 32, 512, 512]" - # t2069 = prims.convert_element_type(t2068, dtypes.float32) # t2069: "cuda:0 f32[1, 32, 512, 512]" - # t2071 = ltorch.amax(t2069, -1, True) # t2071: "cuda:0 f32[1, 32, 512, 1]" - # t2070 = prims.amax(t2069, (3,)) # t2070: "cuda:0 f32[1, 32, 512]" - # t2071 = prims.broadcast_in_dim(t2070, [1, 32, 512, 1], [0, 1, 2]) # t2071: "cuda:0 f32[1, 32, 512, 1]" - # t2073 = ltorch.sub(t2069, t2071, alpha=None) # t2073: "cuda:0 f32[1, 32, 512, 512]" - # t2072 = prims.broadcast_in_dim(t2071, (1, 32, 512, 512), (0, 1, 2, 3)) # t2072: "cuda:0 f32[1, 32, 512, 512]" - # t2073 = prims.sub(t2069, t2072) # t2073: "cuda:0 f32[1, 32, 512, 512]" - # t2074 = ltorch.exp(t2073) # t2074: "cuda:0 f32[1, 32, 512, 512]" - # t2074 = prims.exp(t2073) # t2074: "cuda:0 f32[1, 32, 512, 512]" - # t2076 = ltorch.sum(t2074, -1, True, dtype=None) # t2076: "cuda:0 f32[1, 32, 512, 1]" - # t2075 = prims.sum(t2074, (3,)) # t2075: "cuda:0 f32[1, 32, 512]" - # t2076 = prims.broadcast_in_dim(t2075, [1, 32, 512, 1], [0, 1, 2]) # t2076: "cuda:0 f32[1, 32, 512, 1]" - # t2078 = ltorch.true_divide(t2074, t2076) # t2078: "cuda:0 f32[1, 32, 512, 512]" - # t2077 = prims.broadcast_in_dim(t2076, (1, 32, 512, 512), (0, 1, 2, 3)) # t2077: "cuda:0 f32[1, 32, 512, 512]" - # t2078 = prims.div(t2074, t2077) # t2078: "cuda:0 f32[1, 32, 512, 512]" - # t2079 = ltorch.to(t2078, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2079: "cuda:0 bf16[1, 32, 512, 512]" - # t2079 = prims.convert_element_type(t2078, dtypes.bfloat16) # t2079: "cuda:0 bf16[1, 32, 512, 512]" - # t2080 = ltorch.matmul(t2079, t2016) # t2080: "cuda:0 bf16[1, 32, 512, 128]" - # t2080 = prims.matmul(t2079, t2016) # t2080: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2081 = ltorch.transpose(t2080, 1, 2) # t2081: "cuda:0 bf16[1, 512, 32, 128]" - # t2081 = prims.transpose(t2080, (0, 2, 1, 3)) # t2081: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2082 = ltorch.reshape(t2081, 1, 512, 4096) # t2082: "cuda:0 bf16[1, 512, 4096]" - # t2082 = prims.reshape(t2081, (1, 512, 4096)) # t2082: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2086 = ltorch.linear(t2082, t_transformer_h_12_attn_proj_weight, None) # t2086: "cuda:0 bf16[1, 512, 4096]" - # t2086 = prims.linear(t2082, t_transformer_h_12_attn_proj_weight, None) # t2086: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2090 = ltorch.add(t2086, t1980, alpha=None) # t2090: "cuda:0 bf16[1, 512, 4096]" - # t2087 = prims.convert_element_type(t2086, dtypes.float32) # t2087: "cuda:0 f32[1, 512, 4096]" - # t2088 = prims.convert_element_type(t1980, dtypes.float32) # t2088: "cuda:0 f32[1, 512, 4096]" - # t2089 = prims.add(t2087, t2088) # t2089: "cuda:0 f32[1, 512, 4096]" - # t2090 = prims.convert_element_type(t2089, dtypes.bfloat16) # t2090: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2091 = prims.convert_element_type(t2090, dtypes.float32) # t2091: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2092 = ltorch.mul(t2091, t2091) # t2092: "cuda:0 f32[1, 512, 4096]" - # t2092 = prims.mul(t2091, t2091) # t2092: "cuda:0 f32[1, 512, 4096]" - t2096 = ltorch.mean(t2092, -1, True, dtype=None) # t2096: "cuda:0 f32[1, 512, 1]" - # t2094 = prims.sum(t2092, (2,)) # t2094: "cuda:0 f32[1, 512]" - # t2095 = prims.broadcast_in_dim(t2094, [1, 512, 1], [0, 1]) # t2095: "cuda:0 f32[1, 512, 1]" - # t2096 = ltorch.true_divide(t2095, 4096) # t2096: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2096 = prims.div(t2095, 4096.0) # t2096: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2098 = ltorch.add(t2096, 1e-05, alpha=None) # t2098: "cuda:0 f32[1, 512, 1]" - # t2098 = prims.add(t2096, 1e-05) # t2098: "cuda:0 f32[1, 512, 1]" - t2099 = ltorch.rsqrt(t2098) # t2099: "cuda:0 f32[1, 512, 1]" - # t2099 = prims.rsqrt(t2098) # t2099: "cuda:0 f32[1, 512, 1]" - t2101 = ltorch.mul(t2091, t2099) # t2101: "cuda:0 f32[1, 512, 4096]" - # t2100 = prims.broadcast_in_dim(t2099, (1, 512, 4096), (0, 1, 2)) # t2100: "cuda:0 f32[1, 512, 4096]" - # t2101 = prims.mul(t2091, t2100) # t2101: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2102 = ltorch.to(t2101, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2102: "cuda:0 bf16[1, 512, 4096]" - # t2102 = prims.convert_element_type(t2101, dtypes.bfloat16) # t2102: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2112 = ltorch.mul(t2102, t_transformer_h_12_norm_2_weight) # t2112: "cuda:0 bf16[1, 512, 4096]" - # t2108 = prims.broadcast_in_dim(t_transformer_h_12_norm_2_weight, (1, 512, 4096), (2,)) # t2108: "cuda:0 bf16[1, 512, 4096]" - # t2109 = prims.convert_element_type(t2102, dtypes.float32) # t2109: "cuda:0 f32[1, 512, 4096]" - # t2110 = prims.convert_element_type(t2108, dtypes.float32) # t2110: "cuda:0 f32[1, 512, 4096]" - # t2111 = prims.mul(t2109, t2110) # t2111: "cuda:0 f32[1, 512, 4096]" - # t2112 = prims.convert_element_type(t2111, dtypes.bfloat16) # t2112: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2117 = ltorch.linear(t2112, t_transformer_h_12_mlp_fc_1_weight, None) # t2117: "cuda:0 bf16[1, 512, 11008]" - # t2117 = prims.linear(t2112, t_transformer_h_12_mlp_fc_1_weight, None) # t2117: "cuda:0 bf16[1, 512, 11008]" - t2121 = ltorch.linear(t2112, t_transformer_h_12_mlp_fc_2_weight, None) # t2121: "cuda:0 bf16[1, 512, 11008]" - # t2121 = prims.linear(t2112, t_transformer_h_12_mlp_fc_2_weight, None) # t2121: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2131 = ltorch.silu(t2117, False) # t2131: "cuda:0 bf16[1, 512, 11008]" - # t2122 = prims.convert_element_type(t2117, dtypes.float32) # t2122: "cuda:0 f32[1, 512, 11008]" - # t2123 = prims.neg(t2122) # t2123: "cuda:0 f32[1, 512, 11008]" - # t2124 = prims.exp(t2123) # t2124: "cuda:0 f32[1, 512, 11008]" - # t2125 = prims.add(1.0, t2124) # t2125: "cuda:0 f32[1, 512, 11008]" - # t2126 = prims.reciprocal(t2125) # t2126: "cuda:0 f32[1, 512, 11008]" - # t2127 = prims.convert_element_type(t2126, dtypes.bfloat16) # t2127: "cuda:0 bf16[1, 512, 11008]" - # t2128 = prims.convert_element_type(t2117, dtypes.float32) # t2128: "cuda:0 f32[1, 512, 11008]" - # t2129 = prims.convert_element_type(t2127, dtypes.float32) # t2129: "cuda:0 f32[1, 512, 11008]" - # t2130 = prims.mul(t2128, t2129) # t2130: "cuda:0 f32[1, 512, 11008]" - # t2131 = prims.convert_element_type(t2130, dtypes.bfloat16) # t2131: "cuda:0 bf16[1, 512, 11008]" - t2135 = ltorch.mul(t2131, t2121) # t2135: "cuda:0 bf16[1, 512, 11008]" - # t2132 = prims.convert_element_type(t2131, dtypes.float32) # t2132: "cuda:0 f32[1, 512, 11008]" - # t2133 = prims.convert_element_type(t2121, dtypes.float32) # t2133: "cuda:0 f32[1, 512, 11008]" - # t2134 = prims.mul(t2132, t2133) # t2134: "cuda:0 f32[1, 512, 11008]" - # t2135 = prims.convert_element_type(t2134, dtypes.bfloat16) # t2135: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2139 = ltorch.linear(t2135, t_transformer_h_12_mlp_proj_weight, None) # t2139: "cuda:0 bf16[1, 512, 4096]" - # t2139 = prims.linear(t2135, t_transformer_h_12_mlp_proj_weight, None) # t2139: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2143 = ltorch.add(t2139, t2090, alpha=None) # t2143: "cuda:0 bf16[1, 512, 4096]" - # t2140 = prims.convert_element_type(t2139, dtypes.float32) # t2140: "cuda:0 f32[1, 512, 4096]" - # t2141 = prims.convert_element_type(t2090, dtypes.float32) # t2141: "cuda:0 f32[1, 512, 4096]" - # t2142 = prims.add(t2140, t2141) # t2142: "cuda:0 f32[1, 512, 4096]" - # t2143 = prims.convert_element_type(t2142, dtypes.bfloat16) # t2143: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2145 = prims.convert_element_type(t2143, dtypes.float32) # t2145: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2146 = ltorch.mul(t2145, t2145) # t2146: "cuda:0 f32[1, 512, 4096]" - # t2146 = prims.mul(t2145, t2145) # t2146: "cuda:0 f32[1, 512, 4096]" - t2150 = ltorch.mean(t2146, -1, True, dtype=None) # t2150: "cuda:0 f32[1, 512, 1]" - # t2148 = prims.sum(t2146, (2,)) # t2148: "cuda:0 f32[1, 512]" - # t2149 = prims.broadcast_in_dim(t2148, [1, 512, 1], [0, 1]) # t2149: "cuda:0 f32[1, 512, 1]" - # t2150 = ltorch.true_divide(t2149, 4096) # t2150: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2150 = prims.div(t2149, 4096.0) # t2150: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2152 = ltorch.add(t2150, 1e-05, alpha=None) # t2152: "cuda:0 f32[1, 512, 1]" - # t2152 = prims.add(t2150, 1e-05) # t2152: "cuda:0 f32[1, 512, 1]" - t2153 = ltorch.rsqrt(t2152) # t2153: "cuda:0 f32[1, 512, 1]" - # t2153 = prims.rsqrt(t2152) # t2153: "cuda:0 f32[1, 512, 1]" - t2155 = ltorch.mul(t2145, t2153) # t2155: "cuda:0 f32[1, 512, 4096]" - # t2154 = prims.broadcast_in_dim(t2153, (1, 512, 4096), (0, 1, 2)) # t2154: "cuda:0 f32[1, 512, 4096]" - # t2155 = prims.mul(t2145, t2154) # t2155: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2156 = ltorch.to(t2155, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2156: "cuda:0 bf16[1, 512, 4096]" - # t2156 = prims.convert_element_type(t2155, dtypes.bfloat16) # t2156: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2166 = ltorch.mul(t2156, t_transformer_h_13_norm_1_weight) # t2166: "cuda:0 bf16[1, 512, 4096]" - # t2162 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_weight, (1, 512, 4096), (2,)) # t2162: "cuda:0 bf16[1, 512, 4096]" - # t2163 = prims.convert_element_type(t2156, dtypes.float32) # t2163: "cuda:0 f32[1, 512, 4096]" - # t2164 = prims.convert_element_type(t2162, dtypes.float32) # t2164: "cuda:0 f32[1, 512, 4096]" - # t2165 = prims.mul(t2163, t2164) # t2165: "cuda:0 f32[1, 512, 4096]" - # t2166 = prims.convert_element_type(t2165, dtypes.bfloat16) # t2166: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2171 = ltorch.linear(t2166, t_transformer_h_13_attn_attn_weight, None) # t2171: "cuda:0 bf16[1, 512, 12288]" - # t2171 = prims.linear(t2166, t_transformer_h_13_attn_attn_weight, None) # t2171: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2172 = ltorch.view(t2171, 1, 512, 32, 3, 128) # t2172: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2172 = ltorch.reshape(t2171, (1, 512, 32, 3, 128)) # t2172: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2172 = prims.reshape(t2171, (1, 512, 32, 3, 128)) # t2172: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2173 = ltorch.permute(t2172, 0, 2, 3, 1, 4) # t2173: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2173 = prims.transpose(t2172, (0, 2, 3, 1, 4)) # t2173: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2174, t2175, t2176) = ltorch.split(t2173, (1, 1, 1), 2) - # t2174 = prims.slice_prim(t2173, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2174: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2175 = prims.slice_prim(t2173, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2175: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2176 = prims.slice_prim(t2173, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2176: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2177 = ltorch.reshape(t2174, 1, -1, 512, 128) # t2177: "cuda:0 bf16[1, 32, 512, 128]" - # t2177 = prims.reshape(t2174, (1, 32, 512, 128)) # t2177: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2178 = ltorch.reshape(t2175, 1, -1, 512, 128) # t2178: "cuda:0 bf16[1, 32, 512, 128]" - # t2178 = prims.reshape(t2175, (1, 32, 512, 128)) # t2178: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2179 = ltorch.reshape(t2176, 1, -1, 512, 128) # t2179: "cuda:0 bf16[1, 32, 512, 128]" - # t2179 = prims.reshape(t2176, (1, 32, 512, 128)) # t2179: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2180 = ltorch.getitem(t2177, (..., slice(None, 128, None))) # t2180: "cuda:0 bf16[1, 32, 512, 128]" - # t2180 = prims.slice_prim(t2177, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2180: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2181 = ltorch.getitem(t2180, (..., slice(None, 64, None))) # t2181: "cuda:0 bf16[1, 32, 512, 64]" - # t2181 = prims.slice_prim(t2180, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2181: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2182 = ltorch.getitem(t2180, (..., slice(64, None, None))) # t2182: "cuda:0 bf16[1, 32, 512, 64]" - # t2182 = prims.slice_prim(t2180, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2182: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2185 = ltorch.neg(t2182) # t2185: "cuda:0 bf16[1, 32, 512, 64]" - # t2183 = prims.convert_element_type(t2182, dtypes.float32) # t2183: "cuda:0 f32[1, 32, 512, 64]" - # t2184 = prims.neg(t2183) # t2184: "cuda:0 f32[1, 32, 512, 64]" - # t2185 = prims.convert_element_type(t2184, dtypes.bfloat16) # t2185: "cuda:0 bf16[1, 32, 512, 64]" - t2186 = ltorch.cat((t2185, t2181), -1) # t2186: "cuda:0 bf16[1, 32, 512, 128]" - # t2186 = prims.cat((t2185, t2181), -1) # t2186: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2189 = ltorch.mul(t2180, cos) # t2189: "cuda:0 f32[1, 32, 512, 128]" - # t2187 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2187: "cuda:0 f32[1, 32, 512, 128]" - # t2188 = prims.convert_element_type(t2180, dtypes.float32) # t2188: "cuda:0 f32[1, 32, 512, 128]" - # t2189 = prims.mul(t2188, t2187) # t2189: "cuda:0 f32[1, 32, 512, 128]" - t2192 = ltorch.mul(t2186, sin) # t2192: "cuda:0 f32[1, 32, 512, 128]" - # t2190 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2190: "cuda:0 f32[1, 32, 512, 128]" - # t2191 = prims.convert_element_type(t2186, dtypes.float32) # t2191: "cuda:0 f32[1, 32, 512, 128]" - # t2192 = prims.mul(t2191, t2190) # t2192: "cuda:0 f32[1, 32, 512, 128]" - t2193 = ltorch.add(t2189, t2192, alpha=None) # t2193: "cuda:0 f32[1, 32, 512, 128]" - # t2193 = prims.add(t2189, t2192) # t2193: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2194 = ltorch.to(t2193, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - # t2194 = prims.convert_element_type(t2193, dtypes.bfloat16) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2195 = ltorch.getitem(t2178, (..., slice(None, 128, None))) # t2195: "cuda:0 bf16[1, 32, 512, 128]" - # t2195 = prims.slice_prim(t2178, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2195: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2196 = ltorch.getitem(t2195, (..., slice(None, 64, None))) # t2196: "cuda:0 bf16[1, 32, 512, 64]" - # t2196 = prims.slice_prim(t2195, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2196: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2197 = ltorch.getitem(t2195, (..., slice(64, None, None))) # t2197: "cuda:0 bf16[1, 32, 512, 64]" - # t2197 = prims.slice_prim(t2195, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2197: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2200 = ltorch.neg(t2197) # t2200: "cuda:0 bf16[1, 32, 512, 64]" - # t2198 = prims.convert_element_type(t2197, dtypes.float32) # t2198: "cuda:0 f32[1, 32, 512, 64]" - # t2199 = prims.neg(t2198) # t2199: "cuda:0 f32[1, 32, 512, 64]" - # t2200 = prims.convert_element_type(t2199, dtypes.bfloat16) # t2200: "cuda:0 bf16[1, 32, 512, 64]" - t2201 = ltorch.cat((t2200, t2196), -1) # t2201: "cuda:0 bf16[1, 32, 512, 128]" - # t2201 = prims.cat((t2200, t2196), -1) # t2201: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2204 = ltorch.mul(t2195, cos) # t2204: "cuda:0 f32[1, 32, 512, 128]" - # t2202 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2202: "cuda:0 f32[1, 32, 512, 128]" - # t2203 = prims.convert_element_type(t2195, dtypes.float32) # t2203: "cuda:0 f32[1, 32, 512, 128]" - # t2204 = prims.mul(t2203, t2202) # t2204: "cuda:0 f32[1, 32, 512, 128]" - t2207 = ltorch.mul(t2201, sin) # t2207: "cuda:0 f32[1, 32, 512, 128]" - # t2205 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2205: "cuda:0 f32[1, 32, 512, 128]" - # t2206 = prims.convert_element_type(t2201, dtypes.float32) # t2206: "cuda:0 f32[1, 32, 512, 128]" - # t2207 = prims.mul(t2206, t2205) # t2207: "cuda:0 f32[1, 32, 512, 128]" - t2208 = ltorch.add(t2204, t2207, alpha=None) # t2208: "cuda:0 f32[1, 32, 512, 128]" - # t2208 = prims.add(t2204, t2207) # t2208: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2209 = ltorch.to(t2208, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2209: "cuda:0 bf16[1, 32, 512, 128]" - # t2209 = prims.convert_element_type(t2208, dtypes.bfloat16) # t2209: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2210 = ltorch.getitem(t2177, (..., slice(128, None, None))) # t2210: "cuda:0 bf16[1, 32, 512, 0]" - # t2210 = prims.slice_prim(t2177, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2210: "cuda:0 bf16[1, 32, 512, 0]" - t2211 = ltorch.cat((t2194, t2210), -1) # t2211: "cuda:0 bf16[1, 32, 512, 128]" - # t2211 = prims.cat((t2194, t2210), -1) # t2211: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2212 = ltorch.getitem(t2178, (..., slice(128, None, None))) # t2212: "cuda:0 bf16[1, 32, 512, 0]" - # t2212 = prims.slice_prim(t2178, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2212: "cuda:0 bf16[1, 32, 512, 0]" - t2213 = ltorch.cat((t2209, t2212), -1) # t2213: "cuda:0 bf16[1, 32, 512, 128]" - # t2213 = prims.cat((t2209, t2212), -1) # t2213: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2243 = ltorch.scaled_dot_product_attention(t2211, t2213, t2179, None, 0.0, True, scale=0.08838834764831843) # t2243: "cuda:0 bf16[1, 32, 512, 128]" - # t2216 = ltorch.mul(t2211, 0.29730177875068026) # t2216: "cuda:0 bf16[1, 32, 512, 128]" - # t2214 = prims.convert_element_type(t2211, dtypes.float32) # t2214: "cuda:0 f32[1, 32, 512, 128]" - # t2215 = prims.mul(t2214, 0.29730177875068026) # t2215: "cuda:0 f32[1, 32, 512, 128]" - # t2216 = prims.convert_element_type(t2215, dtypes.bfloat16) # t2216: "cuda:0 bf16[1, 32, 512, 128]" - # t2217 = ltorch.transpose(t2213, -2, -1) # t2217: "cuda:0 bf16[1, 32, 128, 512]" - # t2217 = prims.transpose(t2213, (0, 1, 3, 2)) # t2217: "cuda:0 bf16[1, 32, 128, 512]" - # t2220 = ltorch.mul(t2217, 0.29730177875068026) # t2220: "cuda:0 bf16[1, 32, 128, 512]" - # t2218 = prims.convert_element_type(t2217, dtypes.float32) # t2218: "cuda:0 f32[1, 32, 128, 512]" - # t2219 = prims.mul(t2218, 0.29730177875068026) # t2219: "cuda:0 f32[1, 32, 128, 512]" - # t2220 = prims.convert_element_type(t2219, dtypes.bfloat16) # t2220: "cuda:0 bf16[1, 32, 128, 512]" - # t2221 = ltorch.matmul(t2216, t2220) # t2221: "cuda:0 bf16[1, 32, 512, 512]" - # t2221 = prims.matmul(t2216, t2220) # t2221: "cuda:0 bf16[1, 32, 512, 512]" - # t2231 = ltorch.tril(t2221, 0, fill_value=-float('inf')) # t2231: "cuda:0 bf16[1, 32, 512, 512]" - # t2222 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2222: "cuda:0 i64[512]" - # t2222 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2222: "cuda:0 i64[512]" - # t2223 = ltorch.unsqueeze(t2222, -1) # t2223: "cuda:0 i64[512, 1]" - # t2223 = prims.broadcast_in_dim(t2222, [512, 1], [0]) # t2223: "cuda:0 i64[512, 1]" - # t2224 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2224: "cuda:0 i64[512]" - # t2224 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2224: "cuda:0 i64[512]" - # t2225 = ltorch.unsqueeze(t2224, -2) # t2225: "cuda:0 i64[1, 512]" - # t2225 = prims.broadcast_in_dim(t2224, [1, 512], [1]) # t2225: "cuda:0 i64[1, 512]" - # t2226 = ltorch.add(t2223, 0, alpha=None) # t2226: "cuda:0 i64[512, 1]" - # t2226 = prims.add(t2223, 0) # t2226: "cuda:0 i64[512, 1]" - # t2229 = ltorch.ge(t2226, t2225) # t2229: "cuda:0 b8[512, 512]" - # t2227 = prims.broadcast_in_dim(t2226, (512, 512), (0, 1)) # t2227: "cuda:0 i64[512, 512]" - # t2228 = prims.broadcast_in_dim(t2225, (512, 512), (0, 1)) # t2228: "cuda:0 i64[512, 512]" - # t2229 = prims.ge(t2227, t2228) # t2229: "cuda:0 b8[512, 512]" - # t2231 = ltorch.where(t2229, t2221, -float('inf')) # t2231: "cuda:0 bf16[1, 32, 512, 512]" - # t2230 = prims.broadcast_in_dim(t2229, (1, 32, 512, 512), (2, 3)) # t2230: "cuda:0 b8[1, 32, 512, 512]" - # t2231 = prims.where(t2230, t2221, -float('inf')) # t2231: "cuda:0 bf16[1, 32, 512, 512]" - # t2242 = ltorch._softmax(t2231, -1, dtype=None) # t2242: "cuda:0 bf16[1, 32, 512, 512]" - # t2232 = ltorch.to(t2231, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2232: "cuda:0 f32[1, 32, 512, 512]" - # t2232 = prims.convert_element_type(t2231, dtypes.float32) # t2232: "cuda:0 f32[1, 32, 512, 512]" - # t2234 = ltorch.amax(t2232, -1, True) # t2234: "cuda:0 f32[1, 32, 512, 1]" - # t2233 = prims.amax(t2232, (3,)) # t2233: "cuda:0 f32[1, 32, 512]" - # t2234 = prims.broadcast_in_dim(t2233, [1, 32, 512, 1], [0, 1, 2]) # t2234: "cuda:0 f32[1, 32, 512, 1]" - # t2236 = ltorch.sub(t2232, t2234, alpha=None) # t2236: "cuda:0 f32[1, 32, 512, 512]" - # t2235 = prims.broadcast_in_dim(t2234, (1, 32, 512, 512), (0, 1, 2, 3)) # t2235: "cuda:0 f32[1, 32, 512, 512]" - # t2236 = prims.sub(t2232, t2235) # t2236: "cuda:0 f32[1, 32, 512, 512]" - # t2237 = ltorch.exp(t2236) # t2237: "cuda:0 f32[1, 32, 512, 512]" - # t2237 = prims.exp(t2236) # t2237: "cuda:0 f32[1, 32, 512, 512]" - # t2239 = ltorch.sum(t2237, -1, True, dtype=None) # t2239: "cuda:0 f32[1, 32, 512, 1]" - # t2238 = prims.sum(t2237, (3,)) # t2238: "cuda:0 f32[1, 32, 512]" - # t2239 = prims.broadcast_in_dim(t2238, [1, 32, 512, 1], [0, 1, 2]) # t2239: "cuda:0 f32[1, 32, 512, 1]" - # t2241 = ltorch.true_divide(t2237, t2239) # t2241: "cuda:0 f32[1, 32, 512, 512]" - # t2240 = prims.broadcast_in_dim(t2239, (1, 32, 512, 512), (0, 1, 2, 3)) # t2240: "cuda:0 f32[1, 32, 512, 512]" - # t2241 = prims.div(t2237, t2240) # t2241: "cuda:0 f32[1, 32, 512, 512]" - # t2242 = ltorch.to(t2241, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2242: "cuda:0 bf16[1, 32, 512, 512]" - # t2242 = prims.convert_element_type(t2241, dtypes.bfloat16) # t2242: "cuda:0 bf16[1, 32, 512, 512]" - # t2243 = ltorch.matmul(t2242, t2179) # t2243: "cuda:0 bf16[1, 32, 512, 128]" - # t2243 = prims.matmul(t2242, t2179) # t2243: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2244 = ltorch.transpose(t2243, 1, 2) # t2244: "cuda:0 bf16[1, 512, 32, 128]" - # t2244 = prims.transpose(t2243, (0, 2, 1, 3)) # t2244: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2245 = ltorch.reshape(t2244, 1, 512, 4096) # t2245: "cuda:0 bf16[1, 512, 4096]" - # t2245 = prims.reshape(t2244, (1, 512, 4096)) # t2245: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2249 = ltorch.linear(t2245, t_transformer_h_13_attn_proj_weight, None) # t2249: "cuda:0 bf16[1, 512, 4096]" - # t2249 = prims.linear(t2245, t_transformer_h_13_attn_proj_weight, None) # t2249: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2253 = ltorch.add(t2249, t2143, alpha=None) # t2253: "cuda:0 bf16[1, 512, 4096]" - # t2250 = prims.convert_element_type(t2249, dtypes.float32) # t2250: "cuda:0 f32[1, 512, 4096]" - # t2251 = prims.convert_element_type(t2143, dtypes.float32) # t2251: "cuda:0 f32[1, 512, 4096]" - # t2252 = prims.add(t2250, t2251) # t2252: "cuda:0 f32[1, 512, 4096]" - # t2253 = prims.convert_element_type(t2252, dtypes.bfloat16) # t2253: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2254 = prims.convert_element_type(t2253, dtypes.float32) # t2254: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2255 = ltorch.mul(t2254, t2254) # t2255: "cuda:0 f32[1, 512, 4096]" - # t2255 = prims.mul(t2254, t2254) # t2255: "cuda:0 f32[1, 512, 4096]" - t2259 = ltorch.mean(t2255, -1, True, dtype=None) # t2259: "cuda:0 f32[1, 512, 1]" - # t2257 = prims.sum(t2255, (2,)) # t2257: "cuda:0 f32[1, 512]" - # t2258 = prims.broadcast_in_dim(t2257, [1, 512, 1], [0, 1]) # t2258: "cuda:0 f32[1, 512, 1]" - # t2259 = ltorch.true_divide(t2258, 4096) # t2259: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2259 = prims.div(t2258, 4096.0) # t2259: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2261 = ltorch.add(t2259, 1e-05, alpha=None) # t2261: "cuda:0 f32[1, 512, 1]" - # t2261 = prims.add(t2259, 1e-05) # t2261: "cuda:0 f32[1, 512, 1]" - t2262 = ltorch.rsqrt(t2261) # t2262: "cuda:0 f32[1, 512, 1]" - # t2262 = prims.rsqrt(t2261) # t2262: "cuda:0 f32[1, 512, 1]" - t2264 = ltorch.mul(t2254, t2262) # t2264: "cuda:0 f32[1, 512, 4096]" - # t2263 = prims.broadcast_in_dim(t2262, (1, 512, 4096), (0, 1, 2)) # t2263: "cuda:0 f32[1, 512, 4096]" - # t2264 = prims.mul(t2254, t2263) # t2264: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2265 = ltorch.to(t2264, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2265: "cuda:0 bf16[1, 512, 4096]" - # t2265 = prims.convert_element_type(t2264, dtypes.bfloat16) # t2265: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2275 = ltorch.mul(t2265, t_transformer_h_13_norm_2_weight) # t2275: "cuda:0 bf16[1, 512, 4096]" - # t2271 = prims.broadcast_in_dim(t_transformer_h_13_norm_2_weight, (1, 512, 4096), (2,)) # t2271: "cuda:0 bf16[1, 512, 4096]" - # t2272 = prims.convert_element_type(t2265, dtypes.float32) # t2272: "cuda:0 f32[1, 512, 4096]" - # t2273 = prims.convert_element_type(t2271, dtypes.float32) # t2273: "cuda:0 f32[1, 512, 4096]" - # t2274 = prims.mul(t2272, t2273) # t2274: "cuda:0 f32[1, 512, 4096]" - # t2275 = prims.convert_element_type(t2274, dtypes.bfloat16) # t2275: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2280 = ltorch.linear(t2275, t_transformer_h_13_mlp_fc_1_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - # t2280 = prims.linear(t2275, t_transformer_h_13_mlp_fc_1_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - t2284 = ltorch.linear(t2275, t_transformer_h_13_mlp_fc_2_weight, None) # t2284: "cuda:0 bf16[1, 512, 11008]" - # t2284 = prims.linear(t2275, t_transformer_h_13_mlp_fc_2_weight, None) # t2284: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2294 = ltorch.silu(t2280, False) # t2294: "cuda:0 bf16[1, 512, 11008]" - # t2285 = prims.convert_element_type(t2280, dtypes.float32) # t2285: "cuda:0 f32[1, 512, 11008]" - # t2286 = prims.neg(t2285) # t2286: "cuda:0 f32[1, 512, 11008]" - # t2287 = prims.exp(t2286) # t2287: "cuda:0 f32[1, 512, 11008]" - # t2288 = prims.add(1.0, t2287) # t2288: "cuda:0 f32[1, 512, 11008]" - # t2289 = prims.reciprocal(t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - # t2290 = prims.convert_element_type(t2289, dtypes.bfloat16) # t2290: "cuda:0 bf16[1, 512, 11008]" - # t2291 = prims.convert_element_type(t2280, dtypes.float32) # t2291: "cuda:0 f32[1, 512, 11008]" - # t2292 = prims.convert_element_type(t2290, dtypes.float32) # t2292: "cuda:0 f32[1, 512, 11008]" - # t2293 = prims.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - # t2294 = prims.convert_element_type(t2293, dtypes.bfloat16) # t2294: "cuda:0 bf16[1, 512, 11008]" - t2298 = ltorch.mul(t2294, t2284) # t2298: "cuda:0 bf16[1, 512, 11008]" - # t2295 = prims.convert_element_type(t2294, dtypes.float32) # t2295: "cuda:0 f32[1, 512, 11008]" - # t2296 = prims.convert_element_type(t2284, dtypes.float32) # t2296: "cuda:0 f32[1, 512, 11008]" - # t2297 = prims.mul(t2295, t2296) # t2297: "cuda:0 f32[1, 512, 11008]" - # t2298 = prims.convert_element_type(t2297, dtypes.bfloat16) # t2298: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2302 = ltorch.linear(t2298, t_transformer_h_13_mlp_proj_weight, None) # t2302: "cuda:0 bf16[1, 512, 4096]" - # t2302 = prims.linear(t2298, t_transformer_h_13_mlp_proj_weight, None) # t2302: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2306 = ltorch.add(t2302, t2253, alpha=None) # t2306: "cuda:0 bf16[1, 512, 4096]" - # t2303 = prims.convert_element_type(t2302, dtypes.float32) # t2303: "cuda:0 f32[1, 512, 4096]" - # t2304 = prims.convert_element_type(t2253, dtypes.float32) # t2304: "cuda:0 f32[1, 512, 4096]" - # t2305 = prims.add(t2303, t2304) # t2305: "cuda:0 f32[1, 512, 4096]" - # t2306 = prims.convert_element_type(t2305, dtypes.bfloat16) # t2306: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2308 = prims.convert_element_type(t2306, dtypes.float32) # t2308: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2309 = ltorch.mul(t2308, t2308) # t2309: "cuda:0 f32[1, 512, 4096]" - # t2309 = prims.mul(t2308, t2308) # t2309: "cuda:0 f32[1, 512, 4096]" - t2313 = ltorch.mean(t2309, -1, True, dtype=None) # t2313: "cuda:0 f32[1, 512, 1]" - # t2311 = prims.sum(t2309, (2,)) # t2311: "cuda:0 f32[1, 512]" - # t2312 = prims.broadcast_in_dim(t2311, [1, 512, 1], [0, 1]) # t2312: "cuda:0 f32[1, 512, 1]" - # t2313 = ltorch.true_divide(t2312, 4096) # t2313: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2313 = prims.div(t2312, 4096.0) # t2313: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2315 = ltorch.add(t2313, 1e-05, alpha=None) # t2315: "cuda:0 f32[1, 512, 1]" - # t2315 = prims.add(t2313, 1e-05) # t2315: "cuda:0 f32[1, 512, 1]" - t2316 = ltorch.rsqrt(t2315) # t2316: "cuda:0 f32[1, 512, 1]" - # t2316 = prims.rsqrt(t2315) # t2316: "cuda:0 f32[1, 512, 1]" - t2318 = ltorch.mul(t2308, t2316) # t2318: "cuda:0 f32[1, 512, 4096]" - # t2317 = prims.broadcast_in_dim(t2316, (1, 512, 4096), (0, 1, 2)) # t2317: "cuda:0 f32[1, 512, 4096]" - # t2318 = prims.mul(t2308, t2317) # t2318: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2319 = ltorch.to(t2318, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2319: "cuda:0 bf16[1, 512, 4096]" - # t2319 = prims.convert_element_type(t2318, dtypes.bfloat16) # t2319: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2329 = ltorch.mul(t2319, t_transformer_h_14_norm_1_weight) # t2329: "cuda:0 bf16[1, 512, 4096]" - # t2325 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_weight, (1, 512, 4096), (2,)) # t2325: "cuda:0 bf16[1, 512, 4096]" - # t2326 = prims.convert_element_type(t2319, dtypes.float32) # t2326: "cuda:0 f32[1, 512, 4096]" - # t2327 = prims.convert_element_type(t2325, dtypes.float32) # t2327: "cuda:0 f32[1, 512, 4096]" - # t2328 = prims.mul(t2326, t2327) # t2328: "cuda:0 f32[1, 512, 4096]" - # t2329 = prims.convert_element_type(t2328, dtypes.bfloat16) # t2329: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2334 = ltorch.linear(t2329, t_transformer_h_14_attn_attn_weight, None) # t2334: "cuda:0 bf16[1, 512, 12288]" - # t2334 = prims.linear(t2329, t_transformer_h_14_attn_attn_weight, None) # t2334: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2335 = ltorch.view(t2334, 1, 512, 32, 3, 128) # t2335: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2335 = ltorch.reshape(t2334, (1, 512, 32, 3, 128)) # t2335: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2335 = prims.reshape(t2334, (1, 512, 32, 3, 128)) # t2335: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2336 = ltorch.permute(t2335, 0, 2, 3, 1, 4) # t2336: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2336 = prims.transpose(t2335, (0, 2, 3, 1, 4)) # t2336: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2337, t2338, t2339) = ltorch.split(t2336, (1, 1, 1), 2) - # t2337 = prims.slice_prim(t2336, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2337: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2338 = prims.slice_prim(t2336, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2338: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2339 = prims.slice_prim(t2336, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2339: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2340 = ltorch.reshape(t2337, 1, -1, 512, 128) # t2340: "cuda:0 bf16[1, 32, 512, 128]" - # t2340 = prims.reshape(t2337, (1, 32, 512, 128)) # t2340: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2341 = ltorch.reshape(t2338, 1, -1, 512, 128) # t2341: "cuda:0 bf16[1, 32, 512, 128]" - # t2341 = prims.reshape(t2338, (1, 32, 512, 128)) # t2341: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2342 = ltorch.reshape(t2339, 1, -1, 512, 128) # t2342: "cuda:0 bf16[1, 32, 512, 128]" - # t2342 = prims.reshape(t2339, (1, 32, 512, 128)) # t2342: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2343 = ltorch.getitem(t2340, (..., slice(None, 128, None))) # t2343: "cuda:0 bf16[1, 32, 512, 128]" - # t2343 = prims.slice_prim(t2340, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2343: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2344 = ltorch.getitem(t2343, (..., slice(None, 64, None))) # t2344: "cuda:0 bf16[1, 32, 512, 64]" - # t2344 = prims.slice_prim(t2343, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2344: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2345 = ltorch.getitem(t2343, (..., slice(64, None, None))) # t2345: "cuda:0 bf16[1, 32, 512, 64]" - # t2345 = prims.slice_prim(t2343, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2345: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2348 = ltorch.neg(t2345) # t2348: "cuda:0 bf16[1, 32, 512, 64]" - # t2346 = prims.convert_element_type(t2345, dtypes.float32) # t2346: "cuda:0 f32[1, 32, 512, 64]" - # t2347 = prims.neg(t2346) # t2347: "cuda:0 f32[1, 32, 512, 64]" - # t2348 = prims.convert_element_type(t2347, dtypes.bfloat16) # t2348: "cuda:0 bf16[1, 32, 512, 64]" - t2349 = ltorch.cat((t2348, t2344), -1) # t2349: "cuda:0 bf16[1, 32, 512, 128]" - # t2349 = prims.cat((t2348, t2344), -1) # t2349: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2352 = ltorch.mul(t2343, cos) # t2352: "cuda:0 f32[1, 32, 512, 128]" - # t2350 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2350: "cuda:0 f32[1, 32, 512, 128]" - # t2351 = prims.convert_element_type(t2343, dtypes.float32) # t2351: "cuda:0 f32[1, 32, 512, 128]" - # t2352 = prims.mul(t2351, t2350) # t2352: "cuda:0 f32[1, 32, 512, 128]" - t2355 = ltorch.mul(t2349, sin) # t2355: "cuda:0 f32[1, 32, 512, 128]" - # t2353 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2353: "cuda:0 f32[1, 32, 512, 128]" - # t2354 = prims.convert_element_type(t2349, dtypes.float32) # t2354: "cuda:0 f32[1, 32, 512, 128]" - # t2355 = prims.mul(t2354, t2353) # t2355: "cuda:0 f32[1, 32, 512, 128]" - t2356 = ltorch.add(t2352, t2355, alpha=None) # t2356: "cuda:0 f32[1, 32, 512, 128]" - # t2356 = prims.add(t2352, t2355) # t2356: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2357 = ltorch.to(t2356, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2357: "cuda:0 bf16[1, 32, 512, 128]" - # t2357 = prims.convert_element_type(t2356, dtypes.bfloat16) # t2357: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2358 = ltorch.getitem(t2341, (..., slice(None, 128, None))) # t2358: "cuda:0 bf16[1, 32, 512, 128]" - # t2358 = prims.slice_prim(t2341, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2358: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2359 = ltorch.getitem(t2358, (..., slice(None, 64, None))) # t2359: "cuda:0 bf16[1, 32, 512, 64]" - # t2359 = prims.slice_prim(t2358, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2359: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2360 = ltorch.getitem(t2358, (..., slice(64, None, None))) # t2360: "cuda:0 bf16[1, 32, 512, 64]" - # t2360 = prims.slice_prim(t2358, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2360: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2363 = ltorch.neg(t2360) # t2363: "cuda:0 bf16[1, 32, 512, 64]" - # t2361 = prims.convert_element_type(t2360, dtypes.float32) # t2361: "cuda:0 f32[1, 32, 512, 64]" - # t2362 = prims.neg(t2361) # t2362: "cuda:0 f32[1, 32, 512, 64]" - # t2363 = prims.convert_element_type(t2362, dtypes.bfloat16) # t2363: "cuda:0 bf16[1, 32, 512, 64]" - t2364 = ltorch.cat((t2363, t2359), -1) # t2364: "cuda:0 bf16[1, 32, 512, 128]" - # t2364 = prims.cat((t2363, t2359), -1) # t2364: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2367 = ltorch.mul(t2358, cos) # t2367: "cuda:0 f32[1, 32, 512, 128]" - # t2365 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2365: "cuda:0 f32[1, 32, 512, 128]" - # t2366 = prims.convert_element_type(t2358, dtypes.float32) # t2366: "cuda:0 f32[1, 32, 512, 128]" - # t2367 = prims.mul(t2366, t2365) # t2367: "cuda:0 f32[1, 32, 512, 128]" - t2370 = ltorch.mul(t2364, sin) # t2370: "cuda:0 f32[1, 32, 512, 128]" - # t2368 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2368: "cuda:0 f32[1, 32, 512, 128]" - # t2369 = prims.convert_element_type(t2364, dtypes.float32) # t2369: "cuda:0 f32[1, 32, 512, 128]" - # t2370 = prims.mul(t2369, t2368) # t2370: "cuda:0 f32[1, 32, 512, 128]" - t2371 = ltorch.add(t2367, t2370, alpha=None) # t2371: "cuda:0 f32[1, 32, 512, 128]" - # t2371 = prims.add(t2367, t2370) # t2371: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2372 = ltorch.to(t2371, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2372: "cuda:0 bf16[1, 32, 512, 128]" - # t2372 = prims.convert_element_type(t2371, dtypes.bfloat16) # t2372: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2373 = ltorch.getitem(t2340, (..., slice(128, None, None))) # t2373: "cuda:0 bf16[1, 32, 512, 0]" - # t2373 = prims.slice_prim(t2340, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2373: "cuda:0 bf16[1, 32, 512, 0]" - t2374 = ltorch.cat((t2357, t2373), -1) # t2374: "cuda:0 bf16[1, 32, 512, 128]" - # t2374 = prims.cat((t2357, t2373), -1) # t2374: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2375 = ltorch.getitem(t2341, (..., slice(128, None, None))) # t2375: "cuda:0 bf16[1, 32, 512, 0]" - # t2375 = prims.slice_prim(t2341, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2375: "cuda:0 bf16[1, 32, 512, 0]" - t2376 = ltorch.cat((t2372, t2375), -1) # t2376: "cuda:0 bf16[1, 32, 512, 128]" - # t2376 = prims.cat((t2372, t2375), -1) # t2376: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2406 = ltorch.scaled_dot_product_attention(t2374, t2376, t2342, None, 0.0, True, scale=0.08838834764831843) # t2406: "cuda:0 bf16[1, 32, 512, 128]" - # t2379 = ltorch.mul(t2374, 0.29730177875068026) # t2379: "cuda:0 bf16[1, 32, 512, 128]" - # t2377 = prims.convert_element_type(t2374, dtypes.float32) # t2377: "cuda:0 f32[1, 32, 512, 128]" - # t2378 = prims.mul(t2377, 0.29730177875068026) # t2378: "cuda:0 f32[1, 32, 512, 128]" - # t2379 = prims.convert_element_type(t2378, dtypes.bfloat16) # t2379: "cuda:0 bf16[1, 32, 512, 128]" - # t2380 = ltorch.transpose(t2376, -2, -1) # t2380: "cuda:0 bf16[1, 32, 128, 512]" - # t2380 = prims.transpose(t2376, (0, 1, 3, 2)) # t2380: "cuda:0 bf16[1, 32, 128, 512]" - # t2383 = ltorch.mul(t2380, 0.29730177875068026) # t2383: "cuda:0 bf16[1, 32, 128, 512]" - # t2381 = prims.convert_element_type(t2380, dtypes.float32) # t2381: "cuda:0 f32[1, 32, 128, 512]" - # t2382 = prims.mul(t2381, 0.29730177875068026) # t2382: "cuda:0 f32[1, 32, 128, 512]" - # t2383 = prims.convert_element_type(t2382, dtypes.bfloat16) # t2383: "cuda:0 bf16[1, 32, 128, 512]" - # t2384 = ltorch.matmul(t2379, t2383) # t2384: "cuda:0 bf16[1, 32, 512, 512]" - # t2384 = prims.matmul(t2379, t2383) # t2384: "cuda:0 bf16[1, 32, 512, 512]" - # t2394 = ltorch.tril(t2384, 0, fill_value=-float('inf')) # t2394: "cuda:0 bf16[1, 32, 512, 512]" - # t2385 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2385: "cuda:0 i64[512]" - # t2385 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2385: "cuda:0 i64[512]" - # t2386 = ltorch.unsqueeze(t2385, -1) # t2386: "cuda:0 i64[512, 1]" - # t2386 = prims.broadcast_in_dim(t2385, [512, 1], [0]) # t2386: "cuda:0 i64[512, 1]" - # t2387 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2387: "cuda:0 i64[512]" - # t2387 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2387: "cuda:0 i64[512]" - # t2388 = ltorch.unsqueeze(t2387, -2) # t2388: "cuda:0 i64[1, 512]" - # t2388 = prims.broadcast_in_dim(t2387, [1, 512], [1]) # t2388: "cuda:0 i64[1, 512]" - # t2389 = ltorch.add(t2386, 0, alpha=None) # t2389: "cuda:0 i64[512, 1]" - # t2389 = prims.add(t2386, 0) # t2389: "cuda:0 i64[512, 1]" - # t2392 = ltorch.ge(t2389, t2388) # t2392: "cuda:0 b8[512, 512]" - # t2390 = prims.broadcast_in_dim(t2389, (512, 512), (0, 1)) # t2390: "cuda:0 i64[512, 512]" - # t2391 = prims.broadcast_in_dim(t2388, (512, 512), (0, 1)) # t2391: "cuda:0 i64[512, 512]" - # t2392 = prims.ge(t2390, t2391) # t2392: "cuda:0 b8[512, 512]" - # t2394 = ltorch.where(t2392, t2384, -float('inf')) # t2394: "cuda:0 bf16[1, 32, 512, 512]" - # t2393 = prims.broadcast_in_dim(t2392, (1, 32, 512, 512), (2, 3)) # t2393: "cuda:0 b8[1, 32, 512, 512]" - # t2394 = prims.where(t2393, t2384, -float('inf')) # t2394: "cuda:0 bf16[1, 32, 512, 512]" - # t2405 = ltorch._softmax(t2394, -1, dtype=None) # t2405: "cuda:0 bf16[1, 32, 512, 512]" - # t2395 = ltorch.to(t2394, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2395: "cuda:0 f32[1, 32, 512, 512]" - # t2395 = prims.convert_element_type(t2394, dtypes.float32) # t2395: "cuda:0 f32[1, 32, 512, 512]" - # t2397 = ltorch.amax(t2395, -1, True) # t2397: "cuda:0 f32[1, 32, 512, 1]" - # t2396 = prims.amax(t2395, (3,)) # t2396: "cuda:0 f32[1, 32, 512]" - # t2397 = prims.broadcast_in_dim(t2396, [1, 32, 512, 1], [0, 1, 2]) # t2397: "cuda:0 f32[1, 32, 512, 1]" - # t2399 = ltorch.sub(t2395, t2397, alpha=None) # t2399: "cuda:0 f32[1, 32, 512, 512]" - # t2398 = prims.broadcast_in_dim(t2397, (1, 32, 512, 512), (0, 1, 2, 3)) # t2398: "cuda:0 f32[1, 32, 512, 512]" - # t2399 = prims.sub(t2395, t2398) # t2399: "cuda:0 f32[1, 32, 512, 512]" - # t2400 = ltorch.exp(t2399) # t2400: "cuda:0 f32[1, 32, 512, 512]" - # t2400 = prims.exp(t2399) # t2400: "cuda:0 f32[1, 32, 512, 512]" - # t2402 = ltorch.sum(t2400, -1, True, dtype=None) # t2402: "cuda:0 f32[1, 32, 512, 1]" - # t2401 = prims.sum(t2400, (3,)) # t2401: "cuda:0 f32[1, 32, 512]" - # t2402 = prims.broadcast_in_dim(t2401, [1, 32, 512, 1], [0, 1, 2]) # t2402: "cuda:0 f32[1, 32, 512, 1]" - # t2404 = ltorch.true_divide(t2400, t2402) # t2404: "cuda:0 f32[1, 32, 512, 512]" - # t2403 = prims.broadcast_in_dim(t2402, (1, 32, 512, 512), (0, 1, 2, 3)) # t2403: "cuda:0 f32[1, 32, 512, 512]" - # t2404 = prims.div(t2400, t2403) # t2404: "cuda:0 f32[1, 32, 512, 512]" - # t2405 = ltorch.to(t2404, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2405: "cuda:0 bf16[1, 32, 512, 512]" - # t2405 = prims.convert_element_type(t2404, dtypes.bfloat16) # t2405: "cuda:0 bf16[1, 32, 512, 512]" - # t2406 = ltorch.matmul(t2405, t2342) # t2406: "cuda:0 bf16[1, 32, 512, 128]" - # t2406 = prims.matmul(t2405, t2342) # t2406: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2407 = ltorch.transpose(t2406, 1, 2) # t2407: "cuda:0 bf16[1, 512, 32, 128]" - # t2407 = prims.transpose(t2406, (0, 2, 1, 3)) # t2407: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2408 = ltorch.reshape(t2407, 1, 512, 4096) # t2408: "cuda:0 bf16[1, 512, 4096]" - # t2408 = prims.reshape(t2407, (1, 512, 4096)) # t2408: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2412 = ltorch.linear(t2408, t_transformer_h_14_attn_proj_weight, None) # t2412: "cuda:0 bf16[1, 512, 4096]" - # t2412 = prims.linear(t2408, t_transformer_h_14_attn_proj_weight, None) # t2412: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2416 = ltorch.add(t2412, t2306, alpha=None) # t2416: "cuda:0 bf16[1, 512, 4096]" - # t2413 = prims.convert_element_type(t2412, dtypes.float32) # t2413: "cuda:0 f32[1, 512, 4096]" - # t2414 = prims.convert_element_type(t2306, dtypes.float32) # t2414: "cuda:0 f32[1, 512, 4096]" - # t2415 = prims.add(t2413, t2414) # t2415: "cuda:0 f32[1, 512, 4096]" - # t2416 = prims.convert_element_type(t2415, dtypes.bfloat16) # t2416: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2417 = prims.convert_element_type(t2416, dtypes.float32) # t2417: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2418 = ltorch.mul(t2417, t2417) # t2418: "cuda:0 f32[1, 512, 4096]" - # t2418 = prims.mul(t2417, t2417) # t2418: "cuda:0 f32[1, 512, 4096]" - t2422 = ltorch.mean(t2418, -1, True, dtype=None) # t2422: "cuda:0 f32[1, 512, 1]" - # t2420 = prims.sum(t2418, (2,)) # t2420: "cuda:0 f32[1, 512]" - # t2421 = prims.broadcast_in_dim(t2420, [1, 512, 1], [0, 1]) # t2421: "cuda:0 f32[1, 512, 1]" - # t2422 = ltorch.true_divide(t2421, 4096) # t2422: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2422 = prims.div(t2421, 4096.0) # t2422: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2424 = ltorch.add(t2422, 1e-05, alpha=None) # t2424: "cuda:0 f32[1, 512, 1]" - # t2424 = prims.add(t2422, 1e-05) # t2424: "cuda:0 f32[1, 512, 1]" - t2425 = ltorch.rsqrt(t2424) # t2425: "cuda:0 f32[1, 512, 1]" - # t2425 = prims.rsqrt(t2424) # t2425: "cuda:0 f32[1, 512, 1]" - t2427 = ltorch.mul(t2417, t2425) # t2427: "cuda:0 f32[1, 512, 4096]" - # t2426 = prims.broadcast_in_dim(t2425, (1, 512, 4096), (0, 1, 2)) # t2426: "cuda:0 f32[1, 512, 4096]" - # t2427 = prims.mul(t2417, t2426) # t2427: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2428 = ltorch.to(t2427, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2428: "cuda:0 bf16[1, 512, 4096]" - # t2428 = prims.convert_element_type(t2427, dtypes.bfloat16) # t2428: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2438 = ltorch.mul(t2428, t_transformer_h_14_norm_2_weight) # t2438: "cuda:0 bf16[1, 512, 4096]" - # t2434 = prims.broadcast_in_dim(t_transformer_h_14_norm_2_weight, (1, 512, 4096), (2,)) # t2434: "cuda:0 bf16[1, 512, 4096]" - # t2435 = prims.convert_element_type(t2428, dtypes.float32) # t2435: "cuda:0 f32[1, 512, 4096]" - # t2436 = prims.convert_element_type(t2434, dtypes.float32) # t2436: "cuda:0 f32[1, 512, 4096]" - # t2437 = prims.mul(t2435, t2436) # t2437: "cuda:0 f32[1, 512, 4096]" - # t2438 = prims.convert_element_type(t2437, dtypes.bfloat16) # t2438: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2443 = ltorch.linear(t2438, t_transformer_h_14_mlp_fc_1_weight, None) # t2443: "cuda:0 bf16[1, 512, 11008]" - # t2443 = prims.linear(t2438, t_transformer_h_14_mlp_fc_1_weight, None) # t2443: "cuda:0 bf16[1, 512, 11008]" - t2447 = ltorch.linear(t2438, t_transformer_h_14_mlp_fc_2_weight, None) # t2447: "cuda:0 bf16[1, 512, 11008]" - # t2447 = prims.linear(t2438, t_transformer_h_14_mlp_fc_2_weight, None) # t2447: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2457 = ltorch.silu(t2443, False) # t2457: "cuda:0 bf16[1, 512, 11008]" - # t2448 = prims.convert_element_type(t2443, dtypes.float32) # t2448: "cuda:0 f32[1, 512, 11008]" - # t2449 = prims.neg(t2448) # t2449: "cuda:0 f32[1, 512, 11008]" - # t2450 = prims.exp(t2449) # t2450: "cuda:0 f32[1, 512, 11008]" - # t2451 = prims.add(1.0, t2450) # t2451: "cuda:0 f32[1, 512, 11008]" - # t2452 = prims.reciprocal(t2451) # t2452: "cuda:0 f32[1, 512, 11008]" - # t2453 = prims.convert_element_type(t2452, dtypes.bfloat16) # t2453: "cuda:0 bf16[1, 512, 11008]" - # t2454 = prims.convert_element_type(t2443, dtypes.float32) # t2454: "cuda:0 f32[1, 512, 11008]" - # t2455 = prims.convert_element_type(t2453, dtypes.float32) # t2455: "cuda:0 f32[1, 512, 11008]" - # t2456 = prims.mul(t2454, t2455) # t2456: "cuda:0 f32[1, 512, 11008]" - # t2457 = prims.convert_element_type(t2456, dtypes.bfloat16) # t2457: "cuda:0 bf16[1, 512, 11008]" - t2461 = ltorch.mul(t2457, t2447) # t2461: "cuda:0 bf16[1, 512, 11008]" - # t2458 = prims.convert_element_type(t2457, dtypes.float32) # t2458: "cuda:0 f32[1, 512, 11008]" - # t2459 = prims.convert_element_type(t2447, dtypes.float32) # t2459: "cuda:0 f32[1, 512, 11008]" - # t2460 = prims.mul(t2458, t2459) # t2460: "cuda:0 f32[1, 512, 11008]" - # t2461 = prims.convert_element_type(t2460, dtypes.bfloat16) # t2461: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2465 = ltorch.linear(t2461, t_transformer_h_14_mlp_proj_weight, None) # t2465: "cuda:0 bf16[1, 512, 4096]" - # t2465 = prims.linear(t2461, t_transformer_h_14_mlp_proj_weight, None) # t2465: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2469 = ltorch.add(t2465, t2416, alpha=None) # t2469: "cuda:0 bf16[1, 512, 4096]" - # t2466 = prims.convert_element_type(t2465, dtypes.float32) # t2466: "cuda:0 f32[1, 512, 4096]" - # t2467 = prims.convert_element_type(t2416, dtypes.float32) # t2467: "cuda:0 f32[1, 512, 4096]" - # t2468 = prims.add(t2466, t2467) # t2468: "cuda:0 f32[1, 512, 4096]" - # t2469 = prims.convert_element_type(t2468, dtypes.bfloat16) # t2469: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2471 = prims.convert_element_type(t2469, dtypes.float32) # t2471: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2472 = ltorch.mul(t2471, t2471) # t2472: "cuda:0 f32[1, 512, 4096]" - # t2472 = prims.mul(t2471, t2471) # t2472: "cuda:0 f32[1, 512, 4096]" - t2476 = ltorch.mean(t2472, -1, True, dtype=None) # t2476: "cuda:0 f32[1, 512, 1]" - # t2474 = prims.sum(t2472, (2,)) # t2474: "cuda:0 f32[1, 512]" - # t2475 = prims.broadcast_in_dim(t2474, [1, 512, 1], [0, 1]) # t2475: "cuda:0 f32[1, 512, 1]" - # t2476 = ltorch.true_divide(t2475, 4096) # t2476: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2476 = prims.div(t2475, 4096.0) # t2476: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2478 = ltorch.add(t2476, 1e-05, alpha=None) # t2478: "cuda:0 f32[1, 512, 1]" - # t2478 = prims.add(t2476, 1e-05) # t2478: "cuda:0 f32[1, 512, 1]" - t2479 = ltorch.rsqrt(t2478) # t2479: "cuda:0 f32[1, 512, 1]" - # t2479 = prims.rsqrt(t2478) # t2479: "cuda:0 f32[1, 512, 1]" - t2481 = ltorch.mul(t2471, t2479) # t2481: "cuda:0 f32[1, 512, 4096]" - # t2480 = prims.broadcast_in_dim(t2479, (1, 512, 4096), (0, 1, 2)) # t2480: "cuda:0 f32[1, 512, 4096]" - # t2481 = prims.mul(t2471, t2480) # t2481: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2482 = ltorch.to(t2481, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2482: "cuda:0 bf16[1, 512, 4096]" - # t2482 = prims.convert_element_type(t2481, dtypes.bfloat16) # t2482: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2492 = ltorch.mul(t2482, t_transformer_h_15_norm_1_weight) # t2492: "cuda:0 bf16[1, 512, 4096]" - # t2488 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_weight, (1, 512, 4096), (2,)) # t2488: "cuda:0 bf16[1, 512, 4096]" - # t2489 = prims.convert_element_type(t2482, dtypes.float32) # t2489: "cuda:0 f32[1, 512, 4096]" - # t2490 = prims.convert_element_type(t2488, dtypes.float32) # t2490: "cuda:0 f32[1, 512, 4096]" - # t2491 = prims.mul(t2489, t2490) # t2491: "cuda:0 f32[1, 512, 4096]" - # t2492 = prims.convert_element_type(t2491, dtypes.bfloat16) # t2492: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2497 = ltorch.linear(t2492, t_transformer_h_15_attn_attn_weight, None) # t2497: "cuda:0 bf16[1, 512, 12288]" - # t2497 = prims.linear(t2492, t_transformer_h_15_attn_attn_weight, None) # t2497: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2498 = ltorch.view(t2497, 1, 512, 32, 3, 128) # t2498: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2498 = ltorch.reshape(t2497, (1, 512, 32, 3, 128)) # t2498: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2498 = prims.reshape(t2497, (1, 512, 32, 3, 128)) # t2498: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2499 = ltorch.permute(t2498, 0, 2, 3, 1, 4) # t2499: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2499 = prims.transpose(t2498, (0, 2, 3, 1, 4)) # t2499: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2500, t2501, t2502) = ltorch.split(t2499, (1, 1, 1), 2) - # t2500 = prims.slice_prim(t2499, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2500: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2501 = prims.slice_prim(t2499, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2501: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2502 = prims.slice_prim(t2499, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2502: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2503 = ltorch.reshape(t2500, 1, -1, 512, 128) # t2503: "cuda:0 bf16[1, 32, 512, 128]" - # t2503 = prims.reshape(t2500, (1, 32, 512, 128)) # t2503: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2504 = ltorch.reshape(t2501, 1, -1, 512, 128) # t2504: "cuda:0 bf16[1, 32, 512, 128]" - # t2504 = prims.reshape(t2501, (1, 32, 512, 128)) # t2504: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2505 = ltorch.reshape(t2502, 1, -1, 512, 128) # t2505: "cuda:0 bf16[1, 32, 512, 128]" - # t2505 = prims.reshape(t2502, (1, 32, 512, 128)) # t2505: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2506 = ltorch.getitem(t2503, (..., slice(None, 128, None))) # t2506: "cuda:0 bf16[1, 32, 512, 128]" - # t2506 = prims.slice_prim(t2503, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2506: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2507 = ltorch.getitem(t2506, (..., slice(None, 64, None))) # t2507: "cuda:0 bf16[1, 32, 512, 64]" - # t2507 = prims.slice_prim(t2506, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2507: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2508 = ltorch.getitem(t2506, (..., slice(64, None, None))) # t2508: "cuda:0 bf16[1, 32, 512, 64]" - # t2508 = prims.slice_prim(t2506, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2508: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2511 = ltorch.neg(t2508) # t2511: "cuda:0 bf16[1, 32, 512, 64]" - # t2509 = prims.convert_element_type(t2508, dtypes.float32) # t2509: "cuda:0 f32[1, 32, 512, 64]" - # t2510 = prims.neg(t2509) # t2510: "cuda:0 f32[1, 32, 512, 64]" - # t2511 = prims.convert_element_type(t2510, dtypes.bfloat16) # t2511: "cuda:0 bf16[1, 32, 512, 64]" - t2512 = ltorch.cat((t2511, t2507), -1) # t2512: "cuda:0 bf16[1, 32, 512, 128]" - # t2512 = prims.cat((t2511, t2507), -1) # t2512: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2515 = ltorch.mul(t2506, cos) # t2515: "cuda:0 f32[1, 32, 512, 128]" - # t2513 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2513: "cuda:0 f32[1, 32, 512, 128]" - # t2514 = prims.convert_element_type(t2506, dtypes.float32) # t2514: "cuda:0 f32[1, 32, 512, 128]" - # t2515 = prims.mul(t2514, t2513) # t2515: "cuda:0 f32[1, 32, 512, 128]" - t2518 = ltorch.mul(t2512, sin) # t2518: "cuda:0 f32[1, 32, 512, 128]" - # t2516 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2516: "cuda:0 f32[1, 32, 512, 128]" - # t2517 = prims.convert_element_type(t2512, dtypes.float32) # t2517: "cuda:0 f32[1, 32, 512, 128]" - # t2518 = prims.mul(t2517, t2516) # t2518: "cuda:0 f32[1, 32, 512, 128]" - t2519 = ltorch.add(t2515, t2518, alpha=None) # t2519: "cuda:0 f32[1, 32, 512, 128]" - # t2519 = prims.add(t2515, t2518) # t2519: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2520 = ltorch.to(t2519, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2520: "cuda:0 bf16[1, 32, 512, 128]" - # t2520 = prims.convert_element_type(t2519, dtypes.bfloat16) # t2520: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2521 = ltorch.getitem(t2504, (..., slice(None, 128, None))) # t2521: "cuda:0 bf16[1, 32, 512, 128]" - # t2521 = prims.slice_prim(t2504, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2521: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2522 = ltorch.getitem(t2521, (..., slice(None, 64, None))) # t2522: "cuda:0 bf16[1, 32, 512, 64]" - # t2522 = prims.slice_prim(t2521, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2522: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2523 = ltorch.getitem(t2521, (..., slice(64, None, None))) # t2523: "cuda:0 bf16[1, 32, 512, 64]" - # t2523 = prims.slice_prim(t2521, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2523: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2526 = ltorch.neg(t2523) # t2526: "cuda:0 bf16[1, 32, 512, 64]" - # t2524 = prims.convert_element_type(t2523, dtypes.float32) # t2524: "cuda:0 f32[1, 32, 512, 64]" - # t2525 = prims.neg(t2524) # t2525: "cuda:0 f32[1, 32, 512, 64]" - # t2526 = prims.convert_element_type(t2525, dtypes.bfloat16) # t2526: "cuda:0 bf16[1, 32, 512, 64]" - t2527 = ltorch.cat((t2526, t2522), -1) # t2527: "cuda:0 bf16[1, 32, 512, 128]" - # t2527 = prims.cat((t2526, t2522), -1) # t2527: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2530 = ltorch.mul(t2521, cos) # t2530: "cuda:0 f32[1, 32, 512, 128]" - # t2528 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2528: "cuda:0 f32[1, 32, 512, 128]" - # t2529 = prims.convert_element_type(t2521, dtypes.float32) # t2529: "cuda:0 f32[1, 32, 512, 128]" - # t2530 = prims.mul(t2529, t2528) # t2530: "cuda:0 f32[1, 32, 512, 128]" - t2533 = ltorch.mul(t2527, sin) # t2533: "cuda:0 f32[1, 32, 512, 128]" - # t2531 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2531: "cuda:0 f32[1, 32, 512, 128]" - # t2532 = prims.convert_element_type(t2527, dtypes.float32) # t2532: "cuda:0 f32[1, 32, 512, 128]" - # t2533 = prims.mul(t2532, t2531) # t2533: "cuda:0 f32[1, 32, 512, 128]" - t2534 = ltorch.add(t2530, t2533, alpha=None) # t2534: "cuda:0 f32[1, 32, 512, 128]" - # t2534 = prims.add(t2530, t2533) # t2534: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2535 = ltorch.to(t2534, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2535: "cuda:0 bf16[1, 32, 512, 128]" - # t2535 = prims.convert_element_type(t2534, dtypes.bfloat16) # t2535: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2536 = ltorch.getitem(t2503, (..., slice(128, None, None))) # t2536: "cuda:0 bf16[1, 32, 512, 0]" - # t2536 = prims.slice_prim(t2503, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2536: "cuda:0 bf16[1, 32, 512, 0]" - t2537 = ltorch.cat((t2520, t2536), -1) # t2537: "cuda:0 bf16[1, 32, 512, 128]" - # t2537 = prims.cat((t2520, t2536), -1) # t2537: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2538 = ltorch.getitem(t2504, (..., slice(128, None, None))) # t2538: "cuda:0 bf16[1, 32, 512, 0]" - # t2538 = prims.slice_prim(t2504, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2538: "cuda:0 bf16[1, 32, 512, 0]" - t2539 = ltorch.cat((t2535, t2538), -1) # t2539: "cuda:0 bf16[1, 32, 512, 128]" - # t2539 = prims.cat((t2535, t2538), -1) # t2539: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2569 = ltorch.scaled_dot_product_attention(t2537, t2539, t2505, None, 0.0, True, scale=0.08838834764831843) # t2569: "cuda:0 bf16[1, 32, 512, 128]" - # t2542 = ltorch.mul(t2537, 0.29730177875068026) # t2542: "cuda:0 bf16[1, 32, 512, 128]" - # t2540 = prims.convert_element_type(t2537, dtypes.float32) # t2540: "cuda:0 f32[1, 32, 512, 128]" - # t2541 = prims.mul(t2540, 0.29730177875068026) # t2541: "cuda:0 f32[1, 32, 512, 128]" - # t2542 = prims.convert_element_type(t2541, dtypes.bfloat16) # t2542: "cuda:0 bf16[1, 32, 512, 128]" - # t2543 = ltorch.transpose(t2539, -2, -1) # t2543: "cuda:0 bf16[1, 32, 128, 512]" - # t2543 = prims.transpose(t2539, (0, 1, 3, 2)) # t2543: "cuda:0 bf16[1, 32, 128, 512]" - # t2546 = ltorch.mul(t2543, 0.29730177875068026) # t2546: "cuda:0 bf16[1, 32, 128, 512]" - # t2544 = prims.convert_element_type(t2543, dtypes.float32) # t2544: "cuda:0 f32[1, 32, 128, 512]" - # t2545 = prims.mul(t2544, 0.29730177875068026) # t2545: "cuda:0 f32[1, 32, 128, 512]" - # t2546 = prims.convert_element_type(t2545, dtypes.bfloat16) # t2546: "cuda:0 bf16[1, 32, 128, 512]" - # t2547 = ltorch.matmul(t2542, t2546) # t2547: "cuda:0 bf16[1, 32, 512, 512]" - # t2547 = prims.matmul(t2542, t2546) # t2547: "cuda:0 bf16[1, 32, 512, 512]" - # t2557 = ltorch.tril(t2547, 0, fill_value=-float('inf')) # t2557: "cuda:0 bf16[1, 32, 512, 512]" - # t2548 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2548: "cuda:0 i64[512]" - # t2548 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2548: "cuda:0 i64[512]" - # t2549 = ltorch.unsqueeze(t2548, -1) # t2549: "cuda:0 i64[512, 1]" - # t2549 = prims.broadcast_in_dim(t2548, [512, 1], [0]) # t2549: "cuda:0 i64[512, 1]" - # t2550 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2550: "cuda:0 i64[512]" - # t2550 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2550: "cuda:0 i64[512]" - # t2551 = ltorch.unsqueeze(t2550, -2) # t2551: "cuda:0 i64[1, 512]" - # t2551 = prims.broadcast_in_dim(t2550, [1, 512], [1]) # t2551: "cuda:0 i64[1, 512]" - # t2552 = ltorch.add(t2549, 0, alpha=None) # t2552: "cuda:0 i64[512, 1]" - # t2552 = prims.add(t2549, 0) # t2552: "cuda:0 i64[512, 1]" - # t2555 = ltorch.ge(t2552, t2551) # t2555: "cuda:0 b8[512, 512]" - # t2553 = prims.broadcast_in_dim(t2552, (512, 512), (0, 1)) # t2553: "cuda:0 i64[512, 512]" - # t2554 = prims.broadcast_in_dim(t2551, (512, 512), (0, 1)) # t2554: "cuda:0 i64[512, 512]" - # t2555 = prims.ge(t2553, t2554) # t2555: "cuda:0 b8[512, 512]" - # t2557 = ltorch.where(t2555, t2547, -float('inf')) # t2557: "cuda:0 bf16[1, 32, 512, 512]" - # t2556 = prims.broadcast_in_dim(t2555, (1, 32, 512, 512), (2, 3)) # t2556: "cuda:0 b8[1, 32, 512, 512]" - # t2557 = prims.where(t2556, t2547, -float('inf')) # t2557: "cuda:0 bf16[1, 32, 512, 512]" - # t2568 = ltorch._softmax(t2557, -1, dtype=None) # t2568: "cuda:0 bf16[1, 32, 512, 512]" - # t2558 = ltorch.to(t2557, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2558: "cuda:0 f32[1, 32, 512, 512]" - # t2558 = prims.convert_element_type(t2557, dtypes.float32) # t2558: "cuda:0 f32[1, 32, 512, 512]" - # t2560 = ltorch.amax(t2558, -1, True) # t2560: "cuda:0 f32[1, 32, 512, 1]" - # t2559 = prims.amax(t2558, (3,)) # t2559: "cuda:0 f32[1, 32, 512]" - # t2560 = prims.broadcast_in_dim(t2559, [1, 32, 512, 1], [0, 1, 2]) # t2560: "cuda:0 f32[1, 32, 512, 1]" - # t2562 = ltorch.sub(t2558, t2560, alpha=None) # t2562: "cuda:0 f32[1, 32, 512, 512]" - # t2561 = prims.broadcast_in_dim(t2560, (1, 32, 512, 512), (0, 1, 2, 3)) # t2561: "cuda:0 f32[1, 32, 512, 512]" - # t2562 = prims.sub(t2558, t2561) # t2562: "cuda:0 f32[1, 32, 512, 512]" - # t2563 = ltorch.exp(t2562) # t2563: "cuda:0 f32[1, 32, 512, 512]" - # t2563 = prims.exp(t2562) # t2563: "cuda:0 f32[1, 32, 512, 512]" - # t2565 = ltorch.sum(t2563, -1, True, dtype=None) # t2565: "cuda:0 f32[1, 32, 512, 1]" - # t2564 = prims.sum(t2563, (3,)) # t2564: "cuda:0 f32[1, 32, 512]" - # t2565 = prims.broadcast_in_dim(t2564, [1, 32, 512, 1], [0, 1, 2]) # t2565: "cuda:0 f32[1, 32, 512, 1]" - # t2567 = ltorch.true_divide(t2563, t2565) # t2567: "cuda:0 f32[1, 32, 512, 512]" - # t2566 = prims.broadcast_in_dim(t2565, (1, 32, 512, 512), (0, 1, 2, 3)) # t2566: "cuda:0 f32[1, 32, 512, 512]" - # t2567 = prims.div(t2563, t2566) # t2567: "cuda:0 f32[1, 32, 512, 512]" - # t2568 = ltorch.to(t2567, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2568: "cuda:0 bf16[1, 32, 512, 512]" - # t2568 = prims.convert_element_type(t2567, dtypes.bfloat16) # t2568: "cuda:0 bf16[1, 32, 512, 512]" - # t2569 = ltorch.matmul(t2568, t2505) # t2569: "cuda:0 bf16[1, 32, 512, 128]" - # t2569 = prims.matmul(t2568, t2505) # t2569: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2570 = ltorch.transpose(t2569, 1, 2) # t2570: "cuda:0 bf16[1, 512, 32, 128]" - # t2570 = prims.transpose(t2569, (0, 2, 1, 3)) # t2570: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2571 = ltorch.reshape(t2570, 1, 512, 4096) # t2571: "cuda:0 bf16[1, 512, 4096]" - # t2571 = prims.reshape(t2570, (1, 512, 4096)) # t2571: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2575 = ltorch.linear(t2571, t_transformer_h_15_attn_proj_weight, None) # t2575: "cuda:0 bf16[1, 512, 4096]" - # t2575 = prims.linear(t2571, t_transformer_h_15_attn_proj_weight, None) # t2575: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2579 = ltorch.add(t2575, t2469, alpha=None) # t2579: "cuda:0 bf16[1, 512, 4096]" - # t2576 = prims.convert_element_type(t2575, dtypes.float32) # t2576: "cuda:0 f32[1, 512, 4096]" - # t2577 = prims.convert_element_type(t2469, dtypes.float32) # t2577: "cuda:0 f32[1, 512, 4096]" - # t2578 = prims.add(t2576, t2577) # t2578: "cuda:0 f32[1, 512, 4096]" - # t2579 = prims.convert_element_type(t2578, dtypes.bfloat16) # t2579: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2580 = prims.convert_element_type(t2579, dtypes.float32) # t2580: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2581 = ltorch.mul(t2580, t2580) # t2581: "cuda:0 f32[1, 512, 4096]" - # t2581 = prims.mul(t2580, t2580) # t2581: "cuda:0 f32[1, 512, 4096]" - t2585 = ltorch.mean(t2581, -1, True, dtype=None) # t2585: "cuda:0 f32[1, 512, 1]" - # t2583 = prims.sum(t2581, (2,)) # t2583: "cuda:0 f32[1, 512]" - # t2584 = prims.broadcast_in_dim(t2583, [1, 512, 1], [0, 1]) # t2584: "cuda:0 f32[1, 512, 1]" - # t2585 = ltorch.true_divide(t2584, 4096) # t2585: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2585 = prims.div(t2584, 4096.0) # t2585: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2587 = ltorch.add(t2585, 1e-05, alpha=None) # t2587: "cuda:0 f32[1, 512, 1]" - # t2587 = prims.add(t2585, 1e-05) # t2587: "cuda:0 f32[1, 512, 1]" - t2588 = ltorch.rsqrt(t2587) # t2588: "cuda:0 f32[1, 512, 1]" - # t2588 = prims.rsqrt(t2587) # t2588: "cuda:0 f32[1, 512, 1]" - t2590 = ltorch.mul(t2580, t2588) # t2590: "cuda:0 f32[1, 512, 4096]" - # t2589 = prims.broadcast_in_dim(t2588, (1, 512, 4096), (0, 1, 2)) # t2589: "cuda:0 f32[1, 512, 4096]" - # t2590 = prims.mul(t2580, t2589) # t2590: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2591 = ltorch.to(t2590, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2591: "cuda:0 bf16[1, 512, 4096]" - # t2591 = prims.convert_element_type(t2590, dtypes.bfloat16) # t2591: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2601 = ltorch.mul(t2591, t_transformer_h_15_norm_2_weight) # t2601: "cuda:0 bf16[1, 512, 4096]" - # t2597 = prims.broadcast_in_dim(t_transformer_h_15_norm_2_weight, (1, 512, 4096), (2,)) # t2597: "cuda:0 bf16[1, 512, 4096]" - # t2598 = prims.convert_element_type(t2591, dtypes.float32) # t2598: "cuda:0 f32[1, 512, 4096]" - # t2599 = prims.convert_element_type(t2597, dtypes.float32) # t2599: "cuda:0 f32[1, 512, 4096]" - # t2600 = prims.mul(t2598, t2599) # t2600: "cuda:0 f32[1, 512, 4096]" - # t2601 = prims.convert_element_type(t2600, dtypes.bfloat16) # t2601: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2606 = ltorch.linear(t2601, t_transformer_h_15_mlp_fc_1_weight, None) # t2606: "cuda:0 bf16[1, 512, 11008]" - # t2606 = prims.linear(t2601, t_transformer_h_15_mlp_fc_1_weight, None) # t2606: "cuda:0 bf16[1, 512, 11008]" - t2610 = ltorch.linear(t2601, t_transformer_h_15_mlp_fc_2_weight, None) # t2610: "cuda:0 bf16[1, 512, 11008]" - # t2610 = prims.linear(t2601, t_transformer_h_15_mlp_fc_2_weight, None) # t2610: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2620 = ltorch.silu(t2606, False) # t2620: "cuda:0 bf16[1, 512, 11008]" - # t2611 = prims.convert_element_type(t2606, dtypes.float32) # t2611: "cuda:0 f32[1, 512, 11008]" - # t2612 = prims.neg(t2611) # t2612: "cuda:0 f32[1, 512, 11008]" - # t2613 = prims.exp(t2612) # t2613: "cuda:0 f32[1, 512, 11008]" - # t2614 = prims.add(1.0, t2613) # t2614: "cuda:0 f32[1, 512, 11008]" - # t2615 = prims.reciprocal(t2614) # t2615: "cuda:0 f32[1, 512, 11008]" - # t2616 = prims.convert_element_type(t2615, dtypes.bfloat16) # t2616: "cuda:0 bf16[1, 512, 11008]" - # t2617 = prims.convert_element_type(t2606, dtypes.float32) # t2617: "cuda:0 f32[1, 512, 11008]" - # t2618 = prims.convert_element_type(t2616, dtypes.float32) # t2618: "cuda:0 f32[1, 512, 11008]" - # t2619 = prims.mul(t2617, t2618) # t2619: "cuda:0 f32[1, 512, 11008]" - # t2620 = prims.convert_element_type(t2619, dtypes.bfloat16) # t2620: "cuda:0 bf16[1, 512, 11008]" - t2624 = ltorch.mul(t2620, t2610) # t2624: "cuda:0 bf16[1, 512, 11008]" - # t2621 = prims.convert_element_type(t2620, dtypes.float32) # t2621: "cuda:0 f32[1, 512, 11008]" - # t2622 = prims.convert_element_type(t2610, dtypes.float32) # t2622: "cuda:0 f32[1, 512, 11008]" - # t2623 = prims.mul(t2621, t2622) # t2623: "cuda:0 f32[1, 512, 11008]" - # t2624 = prims.convert_element_type(t2623, dtypes.bfloat16) # t2624: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2628 = ltorch.linear(t2624, t_transformer_h_15_mlp_proj_weight, None) # t2628: "cuda:0 bf16[1, 512, 4096]" - # t2628 = prims.linear(t2624, t_transformer_h_15_mlp_proj_weight, None) # t2628: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2632 = ltorch.add(t2628, t2579, alpha=None) # t2632: "cuda:0 bf16[1, 512, 4096]" - # t2629 = prims.convert_element_type(t2628, dtypes.float32) # t2629: "cuda:0 f32[1, 512, 4096]" - # t2630 = prims.convert_element_type(t2579, dtypes.float32) # t2630: "cuda:0 f32[1, 512, 4096]" - # t2631 = prims.add(t2629, t2630) # t2631: "cuda:0 f32[1, 512, 4096]" - # t2632 = prims.convert_element_type(t2631, dtypes.bfloat16) # t2632: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2633 = prims.convert_element_type(t2632, dtypes.float32) # t2633: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2634 = ltorch.mul(t2633, t2633) # t2634: "cuda:0 f32[1, 512, 4096]" - # t2634 = prims.mul(t2633, t2633) # t2634: "cuda:0 f32[1, 512, 4096]" - t2638 = ltorch.mean(t2634, -1, True, dtype=None) # t2638: "cuda:0 f32[1, 512, 1]" - # t2636 = prims.sum(t2634, (2,)) # t2636: "cuda:0 f32[1, 512]" - # t2637 = prims.broadcast_in_dim(t2636, [1, 512, 1], [0, 1]) # t2637: "cuda:0 f32[1, 512, 1]" - # t2638 = ltorch.true_divide(t2637, 4096) # t2638: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2638 = prims.div(t2637, 4096.0) # t2638: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2640 = ltorch.add(t2638, 1e-05, alpha=None) # t2640: "cuda:0 f32[1, 512, 1]" - # t2640 = prims.add(t2638, 1e-05) # t2640: "cuda:0 f32[1, 512, 1]" - t2641 = ltorch.rsqrt(t2640) # t2641: "cuda:0 f32[1, 512, 1]" - # t2641 = prims.rsqrt(t2640) # t2641: "cuda:0 f32[1, 512, 1]" - t2643 = ltorch.mul(t2633, t2641) # t2643: "cuda:0 f32[1, 512, 4096]" - # t2642 = prims.broadcast_in_dim(t2641, (1, 512, 4096), (0, 1, 2)) # t2642: "cuda:0 f32[1, 512, 4096]" - # t2643 = prims.mul(t2633, t2642) # t2643: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2644 = ltorch.to(t2643, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2644: "cuda:0 bf16[1, 512, 4096]" - # t2644 = prims.convert_element_type(t2643, dtypes.bfloat16) # t2644: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2654 = ltorch.mul(t2644, t_transformer_ln_f_weight) # t2654: "cuda:0 bf16[1, 512, 4096]" - # t2650 = prims.broadcast_in_dim(t_transformer_ln_f_weight, (1, 512, 4096), (2,)) # t2650: "cuda:0 bf16[1, 512, 4096]" - # t2651 = prims.convert_element_type(t2644, dtypes.float32) # t2651: "cuda:0 f32[1, 512, 4096]" - # t2652 = prims.convert_element_type(t2650, dtypes.float32) # t2652: "cuda:0 f32[1, 512, 4096]" - # t2653 = prims.mul(t2651, t2652) # t2653: "cuda:0 f32[1, 512, 4096]" - # t2654 = prims.convert_element_type(t2653, dtypes.bfloat16) # t2654: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2658 = ltorch.linear(t2654, t_lm_head_weight, None) # t2658: "cuda:0 bf16[1, 512, 32000]" - # t2658 = prims.linear(t2654, t_lm_head_weight, None) # t2658: "cuda:0 bf16[1, 512, 32000]" - return t2658 -============================================ END: computation_trc split_forward_backward -============================================ START: primal_trace sort_data_parallel_syncs -# Constructed by Dead Code Elimination (took 4 milliseconds) -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight): - # idx: "cuda:0 i64[1, 512]" - # tos1: "cuda:0 f32[4096, 128]" - # t_lm_head_weight: "cuda:0 bf16[32000, 4096]" - # t_sin: "cuda:0 f32[4096, 128]" - # t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_0_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_0_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_1_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_1_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_1_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_2_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_2_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_2_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_3_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_3_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_3_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_4_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_4_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_4_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_5_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_5_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_5_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_6_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_6_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_6_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_7_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_7_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_7_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_8_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_8_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_8_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_9_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_9_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_9_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_10_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_10_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_10_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_11_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_11_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_11_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_12_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_12_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_12_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_13_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_13_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_13_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_14_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_14_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_14_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_15_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_15_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_15_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_ln_f_weight: "cuda:0 bf16[4096]" - # t_transformer_wte_weight: "cuda:0 bf16[32000, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:85: cos = self.cos[:T] - cos = ltorch.getitem(tos1, slice(None, 512, None)) # cos: "cuda:0 f32[512, 128]" - # cos = prims.slice_prim(tos1, [0, 0], [512, 128], [1, 1]) # cos: "cuda:0 f32[512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:86: sin = self.sin[:T] - sin = ltorch.getitem(t_sin, slice(None, 512, None)) # sin: "cuda:0 f32[512, 128]" - # sin = prims.slice_prim(t_sin, [0, 0], [512, 128], [1, 1]) # sin: "cuda:0 f32[512, 128]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py:190: return F.embedding( - x = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # x: "cuda:0 bf16[1, 512, 4096]" - # t16 = ltorch.reshape(idx, [512]) # t16: "cuda:0 i64[512]" - # t16 = prims.reshape(idx, (512,)) # t16: "cuda:0 i64[512]" - # t17 = prims.take(t_transformer_wte_weight, t16, 0) # t17: "cuda:0 bf16[512, 4096]" - # x = ltorch.reshape(t17, [1, 512, 4096]) # x: "cuda:0 bf16[1, 512, 4096]" - # x = prims.reshape(t17, (1, 512, 4096)) # x: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - a = prims.convert_element_type(x, dtypes.float32) # a: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - result = ltorch.mul(a, a) # result: "cuda:0 f32[1, 512, 4096]" - # result = prims.mul(a, a) # result: "cuda:0 f32[1, 512, 4096]" - norm_x = ltorch.mean(result, -1, True, dtype=None) # norm_x: "cuda:0 f32[1, 512, 1]" - # t24 = prims.sum(result, (2,)) # t24: "cuda:0 f32[1, 512]" - # t25 = prims.broadcast_in_dim(t24, [1, 512, 1], [0, 1]) # t25: "cuda:0 f32[1, 512, 1]" - # norm_x = ltorch.true_divide(t25, 4096) # norm_x: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # norm_x = prims.div(t25, 4096.0) # norm_x: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t28 = ltorch.add(norm_x, 1e-05, alpha=None) # t28: "cuda:0 f32[1, 512, 1]" - # t28 = prims.add(norm_x, 1e-05) # t28: "cuda:0 f32[1, 512, 1]" - b = ltorch.rsqrt(t28) # b: "cuda:0 f32[1, 512, 1]" - # b = prims.rsqrt(t28) # b: "cuda:0 f32[1, 512, 1]" - x_normed = ltorch.mul(a, b) # x_normed: "cuda:0 f32[1, 512, 4096]" - # t30 = prims.broadcast_in_dim(b, (1, 512, 4096), (0, 1, 2)) # t30: "cuda:0 f32[1, 512, 4096]" - # x_normed = prims.mul(a, t30) # x_normed: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t32 = ltorch.to(x_normed, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t32: "cuda:0 bf16[1, 512, 4096]" - # t32 = prims.convert_element_type(x_normed, dtypes.bfloat16) # t32: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - input = ltorch.mul(t32, t_transformer_h_0_norm_1_weight) # input: "cuda:0 bf16[1, 512, 4096]" - # t38 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, (1, 512, 4096), (2,)) # t38: "cuda:0 bf16[1, 512, 4096]" - # t39 = prims.convert_element_type(t32, dtypes.float32) # t39: "cuda:0 f32[1, 512, 4096]" - # t40 = prims.convert_element_type(t38, dtypes.float32) # t40: "cuda:0 f32[1, 512, 4096]" - # t41 = prims.mul(t39, t40) # t41: "cuda:0 f32[1, 512, 4096]" - # input = prims.convert_element_type(t41, dtypes.bfloat16) # input: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - qkv = ltorch.linear(input, t_transformer_h_0_attn_attn_weight, None) # qkv: "cuda:0 bf16[1, 512, 12288]" - # qkv = prims.linear(input, t_transformer_h_0_attn_attn_weight, None) # qkv: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t51 = ltorch.view(qkv, 1, 512, 32, 3, 128) # t51: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t51 = ltorch.reshape(qkv, (1, 512, 32, 3, 128)) # t51: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t51 = prims.reshape(qkv, (1, 512, 32, 3, 128)) # t51: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t52 = ltorch.permute(t51, 0, 2, 3, 1, 4) # t52: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t52 = prims.transpose(t51, (0, 2, 3, 1, 4)) # t52: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (res, k, v) = ltorch.split(t52, (1, 1, 1), 2) - # res = prims.slice_prim(t52, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # res: "cuda:0 bf16[1, 32, 1, 512, 128]" - # k = prims.slice_prim(t52, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # k: "cuda:0 bf16[1, 32, 1, 512, 128]" - # v = prims.slice_prim(t52, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # v: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - q = ltorch.reshape(res, 1, -1, 512, 128) # q: "cuda:0 bf16[1, 32, 512, 128]" - # q = prims.reshape(res, (1, 32, 512, 128)) # q: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t57 = ltorch.reshape(k, 1, -1, 512, 128) # t57: "cuda:0 bf16[1, 32, 512, 128]" - # t57 = prims.reshape(k, (1, 32, 512, 128)) # t57: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t58 = ltorch.reshape(v, 1, -1, 512, 128) # t58: "cuda:0 bf16[1, 32, 512, 128]" - # t58 = prims.reshape(v, (1, 32, 512, 128)) # t58: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t60 = ltorch.getitem(q, (..., slice(None, 128, None))) # t60: "cuda:0 bf16[1, 32, 512, 128]" - # t60 = prims.slice_prim(q, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t60: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - x1 = ltorch.getitem(t60, (..., slice(None, 64, None))) # x1: "cuda:0 bf16[1, 32, 512, 64]" - # x1 = prims.slice_prim(t60, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # x1: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - x2 = ltorch.getitem(t60, (..., slice(64, None, None))) # x2: "cuda:0 bf16[1, 32, 512, 64]" - # x2 = prims.slice_prim(t60, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # x2: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t65 = ltorch.neg(x2) # t65: "cuda:0 bf16[1, 32, 512, 64]" - # t63 = prims.convert_element_type(x2, dtypes.float32) # t63: "cuda:0 f32[1, 32, 512, 64]" - # t64 = prims.neg(t63) # t64: "cuda:0 f32[1, 32, 512, 64]" - # t65 = prims.convert_element_type(t64, dtypes.bfloat16) # t65: "cuda:0 bf16[1, 32, 512, 64]" - rotated = ltorch.cat((t65, x1), -1) # rotated: "cuda:0 bf16[1, 32, 512, 128]" - # rotated = prims.cat((t65, x1), -1) # rotated: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t69 = ltorch.mul(t60, cos) # t69: "cuda:0 f32[1, 32, 512, 128]" - # t67 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t67: "cuda:0 f32[1, 32, 512, 128]" - # t68 = prims.convert_element_type(t60, dtypes.float32) # t68: "cuda:0 f32[1, 32, 512, 128]" - # t69 = prims.mul(t68, t67) # t69: "cuda:0 f32[1, 32, 512, 128]" - t72 = ltorch.mul(rotated, sin) # t72: "cuda:0 f32[1, 32, 512, 128]" - # t70 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t70: "cuda:0 f32[1, 32, 512, 128]" - # t71 = prims.convert_element_type(rotated, dtypes.float32) # t71: "cuda:0 f32[1, 32, 512, 128]" - # t72 = prims.mul(t71, t70) # t72: "cuda:0 f32[1, 32, 512, 128]" - roped = ltorch.add(t69, t72, alpha=None) # roped: "cuda:0 f32[1, 32, 512, 128]" - # roped = prims.add(t69, t72) # roped: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - q_roped = ltorch.to(roped, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # q_roped: "cuda:0 bf16[1, 32, 512, 128]" - # q_roped = prims.convert_element_type(roped, dtypes.bfloat16) # q_roped: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t75 = ltorch.getitem(t57, (..., slice(None, 128, None))) # t75: "cuda:0 bf16[1, 32, 512, 128]" - # t75 = prims.slice_prim(t57, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t75: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t76 = ltorch.getitem(t75, (..., slice(None, 64, None))) # t76: "cuda:0 bf16[1, 32, 512, 64]" - # t76 = prims.slice_prim(t75, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t76: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - tos = ltorch.getitem(t75, (..., slice(64, None, None))) # tos: "cuda:0 bf16[1, 32, 512, 64]" - # tos = prims.slice_prim(t75, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # tos: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t80 = ltorch.neg(tos) # t80: "cuda:0 bf16[1, 32, 512, 64]" - # t78 = prims.convert_element_type(tos, dtypes.float32) # t78: "cuda:0 f32[1, 32, 512, 64]" - # t79 = prims.neg(t78) # t79: "cuda:0 f32[1, 32, 512, 64]" - # t80 = prims.convert_element_type(t79, dtypes.bfloat16) # t80: "cuda:0 bf16[1, 32, 512, 64]" - t81 = ltorch.cat((t80, t76), -1) # t81: "cuda:0 bf16[1, 32, 512, 128]" - # t81 = prims.cat((t80, t76), -1) # t81: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t84 = ltorch.mul(t75, cos) # t84: "cuda:0 f32[1, 32, 512, 128]" - # t82 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t82: "cuda:0 f32[1, 32, 512, 128]" - # t83 = prims.convert_element_type(t75, dtypes.float32) # t83: "cuda:0 f32[1, 32, 512, 128]" - # t84 = prims.mul(t83, t82) # t84: "cuda:0 f32[1, 32, 512, 128]" - t87 = ltorch.mul(t81, sin) # t87: "cuda:0 f32[1, 32, 512, 128]" - # t85 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t85: "cuda:0 f32[1, 32, 512, 128]" - # t86 = prims.convert_element_type(t81, dtypes.float32) # t86: "cuda:0 f32[1, 32, 512, 128]" - # t87 = prims.mul(t86, t85) # t87: "cuda:0 f32[1, 32, 512, 128]" - t88 = ltorch.add(t84, t87, alpha=None) # t88: "cuda:0 f32[1, 32, 512, 128]" - # t88 = prims.add(t84, t87) # t88: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - k_roped = ltorch.to(t88, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # k_roped: "cuda:0 bf16[1, 32, 512, 128]" - # k_roped = prims.convert_element_type(t88, dtypes.bfloat16) # k_roped: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t90 = ltorch.getitem(q, (..., slice(128, None, None))) # t90: "cuda:0 bf16[1, 32, 512, 0]" - # t90 = prims.slice_prim(q, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t90: "cuda:0 bf16[1, 32, 512, 0]" - t91 = ltorch.cat((q_roped, t90), -1) # t91: "cuda:0 bf16[1, 32, 512, 128]" - # t91 = prims.cat((q_roped, t90), -1) # t91: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t92 = ltorch.getitem(t57, (..., slice(128, None, None))) # t92: "cuda:0 bf16[1, 32, 512, 0]" - # t92 = prims.slice_prim(t57, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t92: "cuda:0 bf16[1, 32, 512, 0]" - t93 = ltorch.cat((k_roped, t92), -1) # t93: "cuda:0 bf16[1, 32, 512, 128]" - # t93 = prims.cat((k_roped, t92), -1) # t93: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - y = ltorch.scaled_dot_product_attention(t91, t93, t58, None, 0.0, True, scale=0.08838834764831843) # y: "cuda:0 bf16[1, 32, 512, 128]" - # t96 = ltorch.mul(t91, 0.29730177875068026) # t96: "cuda:0 bf16[1, 32, 512, 128]" - # t94 = prims.convert_element_type(t91, dtypes.float32) # t94: "cuda:0 f32[1, 32, 512, 128]" - # t95 = prims.mul(t94, 0.29730177875068026) # t95: "cuda:0 f32[1, 32, 512, 128]" - # t96 = prims.convert_element_type(t95, dtypes.bfloat16) # t96: "cuda:0 bf16[1, 32, 512, 128]" - # t97 = ltorch.transpose(t93, -2, -1) # t97: "cuda:0 bf16[1, 32, 128, 512]" - # t97 = prims.transpose(t93, (0, 1, 3, 2)) # t97: "cuda:0 bf16[1, 32, 128, 512]" - # t100 = ltorch.mul(t97, 0.29730177875068026) # t100: "cuda:0 bf16[1, 32, 128, 512]" - # t98 = prims.convert_element_type(t97, dtypes.float32) # t98: "cuda:0 f32[1, 32, 128, 512]" - # t99 = prims.mul(t98, 0.29730177875068026) # t99: "cuda:0 f32[1, 32, 128, 512]" - # t100 = prims.convert_element_type(t99, dtypes.bfloat16) # t100: "cuda:0 bf16[1, 32, 128, 512]" - # t101 = ltorch.matmul(t96, t100) # t101: "cuda:0 bf16[1, 32, 512, 512]" - # t101 = prims.matmul(t96, t100) # t101: "cuda:0 bf16[1, 32, 512, 512]" - # t111 = ltorch.tril(t101, 0, fill_value=-float('inf')) # t111: "cuda:0 bf16[1, 32, 512, 512]" - # t102 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t102: "cuda:0 i64[512]" - # t102 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t102: "cuda:0 i64[512]" - # t103 = ltorch.unsqueeze(t102, -1) # t103: "cuda:0 i64[512, 1]" - # t103 = prims.broadcast_in_dim(t102, [512, 1], [0]) # t103: "cuda:0 i64[512, 1]" - # t104 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t104: "cuda:0 i64[512]" - # t104 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t104: "cuda:0 i64[512]" - # t105 = ltorch.unsqueeze(t104, -2) # t105: "cuda:0 i64[1, 512]" - # t105 = prims.broadcast_in_dim(t104, [1, 512], [1]) # t105: "cuda:0 i64[1, 512]" - # t106 = ltorch.add(t103, 0, alpha=None) # t106: "cuda:0 i64[512, 1]" - # t106 = prims.add(t103, 0) # t106: "cuda:0 i64[512, 1]" - # t109 = ltorch.ge(t106, t105) # t109: "cuda:0 b8[512, 512]" - # t107 = prims.broadcast_in_dim(t106, (512, 512), (0, 1)) # t107: "cuda:0 i64[512, 512]" - # t108 = prims.broadcast_in_dim(t105, (512, 512), (0, 1)) # t108: "cuda:0 i64[512, 512]" - # t109 = prims.ge(t107, t108) # t109: "cuda:0 b8[512, 512]" - # t111 = ltorch.where(t109, t101, -float('inf')) # t111: "cuda:0 bf16[1, 32, 512, 512]" - # t110 = prims.broadcast_in_dim(t109, (1, 32, 512, 512), (2, 3)) # t110: "cuda:0 b8[1, 32, 512, 512]" - # t111 = prims.where(t110, t101, -float('inf')) # t111: "cuda:0 bf16[1, 32, 512, 512]" - # t122 = ltorch._softmax(t111, -1, dtype=None) # t122: "cuda:0 bf16[1, 32, 512, 512]" - # t112 = ltorch.to(t111, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t112: "cuda:0 f32[1, 32, 512, 512]" - # t112 = prims.convert_element_type(t111, dtypes.float32) # t112: "cuda:0 f32[1, 32, 512, 512]" - # t114 = ltorch.amax(t112, -1, True) # t114: "cuda:0 f32[1, 32, 512, 1]" - # t113 = prims.amax(t112, (3,)) # t113: "cuda:0 f32[1, 32, 512]" - # t114 = prims.broadcast_in_dim(t113, [1, 32, 512, 1], [0, 1, 2]) # t114: "cuda:0 f32[1, 32, 512, 1]" - # t116 = ltorch.sub(t112, t114, alpha=None) # t116: "cuda:0 f32[1, 32, 512, 512]" - # t115 = prims.broadcast_in_dim(t114, (1, 32, 512, 512), (0, 1, 2, 3)) # t115: "cuda:0 f32[1, 32, 512, 512]" - # t116 = prims.sub(t112, t115) # t116: "cuda:0 f32[1, 32, 512, 512]" - # t117 = ltorch.exp(t116) # t117: "cuda:0 f32[1, 32, 512, 512]" - # t117 = prims.exp(t116) # t117: "cuda:0 f32[1, 32, 512, 512]" - # t119 = ltorch.sum(t117, -1, True, dtype=None) # t119: "cuda:0 f32[1, 32, 512, 1]" - # t118 = prims.sum(t117, (3,)) # t118: "cuda:0 f32[1, 32, 512]" - # t119 = prims.broadcast_in_dim(t118, [1, 32, 512, 1], [0, 1, 2]) # t119: "cuda:0 f32[1, 32, 512, 1]" - # t121 = ltorch.true_divide(t117, t119) # t121: "cuda:0 f32[1, 32, 512, 512]" - # t120 = prims.broadcast_in_dim(t119, (1, 32, 512, 512), (0, 1, 2, 3)) # t120: "cuda:0 f32[1, 32, 512, 512]" - # t121 = prims.div(t117, t120) # t121: "cuda:0 f32[1, 32, 512, 512]" - # t122 = ltorch.to(t121, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t122: "cuda:0 bf16[1, 32, 512, 512]" - # t122 = prims.convert_element_type(t121, dtypes.bfloat16) # t122: "cuda:0 bf16[1, 32, 512, 512]" - # y = ltorch.matmul(t122, t58) # y: "cuda:0 bf16[1, 32, 512, 128]" - # y = prims.matmul(t122, t58) # y: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t124 = ltorch.transpose(y, 1, 2) # t124: "cuda:0 bf16[1, 512, 32, 128]" - # t124 = prims.transpose(y, (0, 2, 1, 3)) # t124: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t125 = ltorch.reshape(t124, 1, 512, 4096) # t125: "cuda:0 bf16[1, 512, 4096]" - # t125 = prims.reshape(t124, (1, 512, 4096)) # t125: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - attention_output = ltorch.linear(t125, t_transformer_h_0_attn_proj_weight, None) # attention_output: "cuda:0 bf16[1, 512, 4096]" - # attention_output = prims.linear(t125, t_transformer_h_0_attn_proj_weight, None) # attention_output: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t134 = ltorch.add(attention_output, x, alpha=None) # t134: "cuda:0 bf16[1, 512, 4096]" - # t131 = prims.convert_element_type(attention_output, dtypes.float32) # t131: "cuda:0 f32[1, 512, 4096]" - # t132 = prims.convert_element_type(x, dtypes.float32) # t132: "cuda:0 f32[1, 512, 4096]" - # t133 = prims.add(t131, t132) # t133: "cuda:0 f32[1, 512, 4096]" - # t134 = prims.convert_element_type(t133, dtypes.bfloat16) # t134: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t135 = prims.convert_element_type(t134, dtypes.float32) # t135: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t136 = ltorch.mul(t135, t135) # t136: "cuda:0 f32[1, 512, 4096]" - # t136 = prims.mul(t135, t135) # t136: "cuda:0 f32[1, 512, 4096]" - t140 = ltorch.mean(t136, -1, True, dtype=None) # t140: "cuda:0 f32[1, 512, 1]" - # t138 = prims.sum(t136, (2,)) # t138: "cuda:0 f32[1, 512]" - # t139 = prims.broadcast_in_dim(t138, [1, 512, 1], [0, 1]) # t139: "cuda:0 f32[1, 512, 1]" - # t140 = ltorch.true_divide(t139, 4096) # t140: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t140 = prims.div(t139, 4096.0) # t140: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t142 = ltorch.add(t140, 1e-05, alpha=None) # t142: "cuda:0 f32[1, 512, 1]" - # t142 = prims.add(t140, 1e-05) # t142: "cuda:0 f32[1, 512, 1]" - t143 = ltorch.rsqrt(t142) # t143: "cuda:0 f32[1, 512, 1]" - # t143 = prims.rsqrt(t142) # t143: "cuda:0 f32[1, 512, 1]" - t145 = ltorch.mul(t135, t143) # t145: "cuda:0 f32[1, 512, 4096]" - # t144 = prims.broadcast_in_dim(t143, (1, 512, 4096), (0, 1, 2)) # t144: "cuda:0 f32[1, 512, 4096]" - # t145 = prims.mul(t135, t144) # t145: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t146 = ltorch.to(t145, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t146: "cuda:0 bf16[1, 512, 4096]" - # t146 = prims.convert_element_type(t145, dtypes.bfloat16) # t146: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t156 = ltorch.mul(t146, t_transformer_h_0_norm_2_weight) # t156: "cuda:0 bf16[1, 512, 4096]" - # t152 = prims.broadcast_in_dim(t_transformer_h_0_norm_2_weight, (1, 512, 4096), (2,)) # t152: "cuda:0 bf16[1, 512, 4096]" - # t153 = prims.convert_element_type(t146, dtypes.float32) # t153: "cuda:0 f32[1, 512, 4096]" - # t154 = prims.convert_element_type(t152, dtypes.float32) # t154: "cuda:0 f32[1, 512, 4096]" - # t155 = prims.mul(t153, t154) # t155: "cuda:0 f32[1, 512, 4096]" - # t156 = prims.convert_element_type(t155, dtypes.bfloat16) # t156: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - x_fc_1 = ltorch.linear(t156, t_transformer_h_0_mlp_fc_1_weight, None) # x_fc_1: "cuda:0 bf16[1, 512, 11008]" - # x_fc_1 = prims.linear(t156, t_transformer_h_0_mlp_fc_1_weight, None) # x_fc_1: "cuda:0 bf16[1, 512, 11008]" - x_fc_2 = ltorch.linear(t156, t_transformer_h_0_mlp_fc_2_weight, None) # x_fc_2: "cuda:0 bf16[1, 512, 11008]" - # x_fc_2 = prims.linear(t156, t_transformer_h_0_mlp_fc_2_weight, None) # x_fc_2: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t175 = ltorch.silu(x_fc_1, False) # t175: "cuda:0 bf16[1, 512, 11008]" - # t166 = prims.convert_element_type(x_fc_1, dtypes.float32) # t166: "cuda:0 f32[1, 512, 11008]" - # t167 = prims.neg(t166) # t167: "cuda:0 f32[1, 512, 11008]" - # t168 = prims.exp(t167) # t168: "cuda:0 f32[1, 512, 11008]" - # t169 = prims.add(1.0, t168) # t169: "cuda:0 f32[1, 512, 11008]" - # t170 = prims.reciprocal(t169) # t170: "cuda:0 f32[1, 512, 11008]" - # t171 = prims.convert_element_type(t170, dtypes.bfloat16) # t171: "cuda:0 bf16[1, 512, 11008]" - # t172 = prims.convert_element_type(x_fc_1, dtypes.float32) # t172: "cuda:0 f32[1, 512, 11008]" - # t173 = prims.convert_element_type(t171, dtypes.float32) # t173: "cuda:0 f32[1, 512, 11008]" - # t174 = prims.mul(t172, t173) # t174: "cuda:0 f32[1, 512, 11008]" - # t175 = prims.convert_element_type(t174, dtypes.bfloat16) # t175: "cuda:0 bf16[1, 512, 11008]" - t179 = ltorch.mul(t175, x_fc_2) # t179: "cuda:0 bf16[1, 512, 11008]" - # t176 = prims.convert_element_type(t175, dtypes.float32) # t176: "cuda:0 f32[1, 512, 11008]" - # t177 = prims.convert_element_type(x_fc_2, dtypes.float32) # t177: "cuda:0 f32[1, 512, 11008]" - # t178 = prims.mul(t176, t177) # t178: "cuda:0 f32[1, 512, 11008]" - # t179 = prims.convert_element_type(t178, dtypes.bfloat16) # t179: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t183 = ltorch.linear(t179, t_transformer_h_0_mlp_proj_weight, None) # t183: "cuda:0 bf16[1, 512, 4096]" - # t183 = prims.linear(t179, t_transformer_h_0_mlp_proj_weight, None) # t183: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t187 = ltorch.add(t183, t134, alpha=None) # t187: "cuda:0 bf16[1, 512, 4096]" - # t184 = prims.convert_element_type(t183, dtypes.float32) # t184: "cuda:0 f32[1, 512, 4096]" - # t185 = prims.convert_element_type(t134, dtypes.float32) # t185: "cuda:0 f32[1, 512, 4096]" - # t186 = prims.add(t184, t185) # t186: "cuda:0 f32[1, 512, 4096]" - # t187 = prims.convert_element_type(t186, dtypes.bfloat16) # t187: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t189 = prims.convert_element_type(t187, dtypes.float32) # t189: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t190 = ltorch.mul(t189, t189) # t190: "cuda:0 f32[1, 512, 4096]" - # t190 = prims.mul(t189, t189) # t190: "cuda:0 f32[1, 512, 4096]" - t194 = ltorch.mean(t190, -1, True, dtype=None) # t194: "cuda:0 f32[1, 512, 1]" - # t192 = prims.sum(t190, (2,)) # t192: "cuda:0 f32[1, 512]" - # t193 = prims.broadcast_in_dim(t192, [1, 512, 1], [0, 1]) # t193: "cuda:0 f32[1, 512, 1]" - # t194 = ltorch.true_divide(t193, 4096) # t194: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t194 = prims.div(t193, 4096.0) # t194: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t196 = ltorch.add(t194, 1e-05, alpha=None) # t196: "cuda:0 f32[1, 512, 1]" - # t196 = prims.add(t194, 1e-05) # t196: "cuda:0 f32[1, 512, 1]" - t197 = ltorch.rsqrt(t196) # t197: "cuda:0 f32[1, 512, 1]" - # t197 = prims.rsqrt(t196) # t197: "cuda:0 f32[1, 512, 1]" - t199 = ltorch.mul(t189, t197) # t199: "cuda:0 f32[1, 512, 4096]" - # t198 = prims.broadcast_in_dim(t197, (1, 512, 4096), (0, 1, 2)) # t198: "cuda:0 f32[1, 512, 4096]" - # t199 = prims.mul(t189, t198) # t199: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t200 = ltorch.to(t199, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t200: "cuda:0 bf16[1, 512, 4096]" - # t200 = prims.convert_element_type(t199, dtypes.bfloat16) # t200: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t210 = ltorch.mul(t200, t_transformer_h_1_norm_1_weight) # t210: "cuda:0 bf16[1, 512, 4096]" - # t206 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_weight, (1, 512, 4096), (2,)) # t206: "cuda:0 bf16[1, 512, 4096]" - # t207 = prims.convert_element_type(t200, dtypes.float32) # t207: "cuda:0 f32[1, 512, 4096]" - # t208 = prims.convert_element_type(t206, dtypes.float32) # t208: "cuda:0 f32[1, 512, 4096]" - # t209 = prims.mul(t207, t208) # t209: "cuda:0 f32[1, 512, 4096]" - # t210 = prims.convert_element_type(t209, dtypes.bfloat16) # t210: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t215 = ltorch.linear(t210, t_transformer_h_1_attn_attn_weight, None) # t215: "cuda:0 bf16[1, 512, 12288]" - # t215 = prims.linear(t210, t_transformer_h_1_attn_attn_weight, None) # t215: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t216 = ltorch.view(t215, 1, 512, 32, 3, 128) # t216: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t216 = ltorch.reshape(t215, (1, 512, 32, 3, 128)) # t216: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t216 = prims.reshape(t215, (1, 512, 32, 3, 128)) # t216: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t217 = ltorch.permute(t216, 0, 2, 3, 1, 4) # t217: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t217 = prims.transpose(t216, (0, 2, 3, 1, 4)) # t217: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t218, t219, t220) = ltorch.split(t217, (1, 1, 1), 2) - # t218 = prims.slice_prim(t217, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t218: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t219 = prims.slice_prim(t217, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t219: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t220 = prims.slice_prim(t217, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t220: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t221 = ltorch.reshape(t218, 1, -1, 512, 128) # t221: "cuda:0 bf16[1, 32, 512, 128]" - # t221 = prims.reshape(t218, (1, 32, 512, 128)) # t221: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t222 = ltorch.reshape(t219, 1, -1, 512, 128) # t222: "cuda:0 bf16[1, 32, 512, 128]" - # t222 = prims.reshape(t219, (1, 32, 512, 128)) # t222: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t223 = ltorch.reshape(t220, 1, -1, 512, 128) # t223: "cuda:0 bf16[1, 32, 512, 128]" - # t223 = prims.reshape(t220, (1, 32, 512, 128)) # t223: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t224 = ltorch.getitem(t221, (..., slice(None, 128, None))) # t224: "cuda:0 bf16[1, 32, 512, 128]" - # t224 = prims.slice_prim(t221, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t224: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t225 = ltorch.getitem(t224, (..., slice(None, 64, None))) # t225: "cuda:0 bf16[1, 32, 512, 64]" - # t225 = prims.slice_prim(t224, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t225: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t226 = ltorch.getitem(t224, (..., slice(64, None, None))) # t226: "cuda:0 bf16[1, 32, 512, 64]" - # t226 = prims.slice_prim(t224, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t226: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t229 = ltorch.neg(t226) # t229: "cuda:0 bf16[1, 32, 512, 64]" - # t227 = prims.convert_element_type(t226, dtypes.float32) # t227: "cuda:0 f32[1, 32, 512, 64]" - # t228 = prims.neg(t227) # t228: "cuda:0 f32[1, 32, 512, 64]" - # t229 = prims.convert_element_type(t228, dtypes.bfloat16) # t229: "cuda:0 bf16[1, 32, 512, 64]" - t230 = ltorch.cat((t229, t225), -1) # t230: "cuda:0 bf16[1, 32, 512, 128]" - # t230 = prims.cat((t229, t225), -1) # t230: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t233 = ltorch.mul(t224, cos) # t233: "cuda:0 f32[1, 32, 512, 128]" - # t231 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t231: "cuda:0 f32[1, 32, 512, 128]" - # t232 = prims.convert_element_type(t224, dtypes.float32) # t232: "cuda:0 f32[1, 32, 512, 128]" - # t233 = prims.mul(t232, t231) # t233: "cuda:0 f32[1, 32, 512, 128]" - t236 = ltorch.mul(t230, sin) # t236: "cuda:0 f32[1, 32, 512, 128]" - # t234 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t234: "cuda:0 f32[1, 32, 512, 128]" - # t235 = prims.convert_element_type(t230, dtypes.float32) # t235: "cuda:0 f32[1, 32, 512, 128]" - # t236 = prims.mul(t235, t234) # t236: "cuda:0 f32[1, 32, 512, 128]" - t237 = ltorch.add(t233, t236, alpha=None) # t237: "cuda:0 f32[1, 32, 512, 128]" - # t237 = prims.add(t233, t236) # t237: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t238 = ltorch.to(t237, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t238: "cuda:0 bf16[1, 32, 512, 128]" - # t238 = prims.convert_element_type(t237, dtypes.bfloat16) # t238: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t239 = ltorch.getitem(t222, (..., slice(None, 128, None))) # t239: "cuda:0 bf16[1, 32, 512, 128]" - # t239 = prims.slice_prim(t222, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t239: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t240 = ltorch.getitem(t239, (..., slice(None, 64, None))) # t240: "cuda:0 bf16[1, 32, 512, 64]" - # t240 = prims.slice_prim(t239, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t240: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t241 = ltorch.getitem(t239, (..., slice(64, None, None))) # t241: "cuda:0 bf16[1, 32, 512, 64]" - # t241 = prims.slice_prim(t239, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t241: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t244 = ltorch.neg(t241) # t244: "cuda:0 bf16[1, 32, 512, 64]" - # t242 = prims.convert_element_type(t241, dtypes.float32) # t242: "cuda:0 f32[1, 32, 512, 64]" - # t243 = prims.neg(t242) # t243: "cuda:0 f32[1, 32, 512, 64]" - # t244 = prims.convert_element_type(t243, dtypes.bfloat16) # t244: "cuda:0 bf16[1, 32, 512, 64]" - t245 = ltorch.cat((t244, t240), -1) # t245: "cuda:0 bf16[1, 32, 512, 128]" - # t245 = prims.cat((t244, t240), -1) # t245: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t248 = ltorch.mul(t239, cos) # t248: "cuda:0 f32[1, 32, 512, 128]" - # t246 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t246: "cuda:0 f32[1, 32, 512, 128]" - # t247 = prims.convert_element_type(t239, dtypes.float32) # t247: "cuda:0 f32[1, 32, 512, 128]" - # t248 = prims.mul(t247, t246) # t248: "cuda:0 f32[1, 32, 512, 128]" - t251 = ltorch.mul(t245, sin) # t251: "cuda:0 f32[1, 32, 512, 128]" - # t249 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t249: "cuda:0 f32[1, 32, 512, 128]" - # t250 = prims.convert_element_type(t245, dtypes.float32) # t250: "cuda:0 f32[1, 32, 512, 128]" - # t251 = prims.mul(t250, t249) # t251: "cuda:0 f32[1, 32, 512, 128]" - t252 = ltorch.add(t248, t251, alpha=None) # t252: "cuda:0 f32[1, 32, 512, 128]" - # t252 = prims.add(t248, t251) # t252: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t253 = ltorch.to(t252, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t253: "cuda:0 bf16[1, 32, 512, 128]" - # t253 = prims.convert_element_type(t252, dtypes.bfloat16) # t253: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t254 = ltorch.getitem(t221, (..., slice(128, None, None))) # t254: "cuda:0 bf16[1, 32, 512, 0]" - # t254 = prims.slice_prim(t221, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t254: "cuda:0 bf16[1, 32, 512, 0]" - t255 = ltorch.cat((t238, t254), -1) # t255: "cuda:0 bf16[1, 32, 512, 128]" - # t255 = prims.cat((t238, t254), -1) # t255: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t256 = ltorch.getitem(t222, (..., slice(128, None, None))) # t256: "cuda:0 bf16[1, 32, 512, 0]" - # t256 = prims.slice_prim(t222, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t256: "cuda:0 bf16[1, 32, 512, 0]" - t257 = ltorch.cat((t253, t256), -1) # t257: "cuda:0 bf16[1, 32, 512, 128]" - # t257 = prims.cat((t253, t256), -1) # t257: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t287 = ltorch.scaled_dot_product_attention(t255, t257, t223, None, 0.0, True, scale=0.08838834764831843) # t287: "cuda:0 bf16[1, 32, 512, 128]" - # t260 = ltorch.mul(t255, 0.29730177875068026) # t260: "cuda:0 bf16[1, 32, 512, 128]" - # t258 = prims.convert_element_type(t255, dtypes.float32) # t258: "cuda:0 f32[1, 32, 512, 128]" - # t259 = prims.mul(t258, 0.29730177875068026) # t259: "cuda:0 f32[1, 32, 512, 128]" - # t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 32, 512, 128]" - # t261 = ltorch.transpose(t257, -2, -1) # t261: "cuda:0 bf16[1, 32, 128, 512]" - # t261 = prims.transpose(t257, (0, 1, 3, 2)) # t261: "cuda:0 bf16[1, 32, 128, 512]" - # t264 = ltorch.mul(t261, 0.29730177875068026) # t264: "cuda:0 bf16[1, 32, 128, 512]" - # t262 = prims.convert_element_type(t261, dtypes.float32) # t262: "cuda:0 f32[1, 32, 128, 512]" - # t263 = prims.mul(t262, 0.29730177875068026) # t263: "cuda:0 f32[1, 32, 128, 512]" - # t264 = prims.convert_element_type(t263, dtypes.bfloat16) # t264: "cuda:0 bf16[1, 32, 128, 512]" - # t265 = ltorch.matmul(t260, t264) # t265: "cuda:0 bf16[1, 32, 512, 512]" - # t265 = prims.matmul(t260, t264) # t265: "cuda:0 bf16[1, 32, 512, 512]" - # t275 = ltorch.tril(t265, 0, fill_value=-float('inf')) # t275: "cuda:0 bf16[1, 32, 512, 512]" - # t266 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t266: "cuda:0 i64[512]" - # t266 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t266: "cuda:0 i64[512]" - # t267 = ltorch.unsqueeze(t266, -1) # t267: "cuda:0 i64[512, 1]" - # t267 = prims.broadcast_in_dim(t266, [512, 1], [0]) # t267: "cuda:0 i64[512, 1]" - # t268 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t268: "cuda:0 i64[512]" - # t268 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t268: "cuda:0 i64[512]" - # t269 = ltorch.unsqueeze(t268, -2) # t269: "cuda:0 i64[1, 512]" - # t269 = prims.broadcast_in_dim(t268, [1, 512], [1]) # t269: "cuda:0 i64[1, 512]" - # t270 = ltorch.add(t267, 0, alpha=None) # t270: "cuda:0 i64[512, 1]" - # t270 = prims.add(t267, 0) # t270: "cuda:0 i64[512, 1]" - # t273 = ltorch.ge(t270, t269) # t273: "cuda:0 b8[512, 512]" - # t271 = prims.broadcast_in_dim(t270, (512, 512), (0, 1)) # t271: "cuda:0 i64[512, 512]" - # t272 = prims.broadcast_in_dim(t269, (512, 512), (0, 1)) # t272: "cuda:0 i64[512, 512]" - # t273 = prims.ge(t271, t272) # t273: "cuda:0 b8[512, 512]" - # t275 = ltorch.where(t273, t265, -float('inf')) # t275: "cuda:0 bf16[1, 32, 512, 512]" - # t274 = prims.broadcast_in_dim(t273, (1, 32, 512, 512), (2, 3)) # t274: "cuda:0 b8[1, 32, 512, 512]" - # t275 = prims.where(t274, t265, -float('inf')) # t275: "cuda:0 bf16[1, 32, 512, 512]" - # t286 = ltorch._softmax(t275, -1, dtype=None) # t286: "cuda:0 bf16[1, 32, 512, 512]" - # t276 = ltorch.to(t275, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t276: "cuda:0 f32[1, 32, 512, 512]" - # t276 = prims.convert_element_type(t275, dtypes.float32) # t276: "cuda:0 f32[1, 32, 512, 512]" - # t278 = ltorch.amax(t276, -1, True) # t278: "cuda:0 f32[1, 32, 512, 1]" - # t277 = prims.amax(t276, (3,)) # t277: "cuda:0 f32[1, 32, 512]" - # t278 = prims.broadcast_in_dim(t277, [1, 32, 512, 1], [0, 1, 2]) # t278: "cuda:0 f32[1, 32, 512, 1]" - # t280 = ltorch.sub(t276, t278, alpha=None) # t280: "cuda:0 f32[1, 32, 512, 512]" - # t279 = prims.broadcast_in_dim(t278, (1, 32, 512, 512), (0, 1, 2, 3)) # t279: "cuda:0 f32[1, 32, 512, 512]" - # t280 = prims.sub(t276, t279) # t280: "cuda:0 f32[1, 32, 512, 512]" - # t281 = ltorch.exp(t280) # t281: "cuda:0 f32[1, 32, 512, 512]" - # t281 = prims.exp(t280) # t281: "cuda:0 f32[1, 32, 512, 512]" - # t283 = ltorch.sum(t281, -1, True, dtype=None) # t283: "cuda:0 f32[1, 32, 512, 1]" - # t282 = prims.sum(t281, (3,)) # t282: "cuda:0 f32[1, 32, 512]" - # t283 = prims.broadcast_in_dim(t282, [1, 32, 512, 1], [0, 1, 2]) # t283: "cuda:0 f32[1, 32, 512, 1]" - # t285 = ltorch.true_divide(t281, t283) # t285: "cuda:0 f32[1, 32, 512, 512]" - # t284 = prims.broadcast_in_dim(t283, (1, 32, 512, 512), (0, 1, 2, 3)) # t284: "cuda:0 f32[1, 32, 512, 512]" - # t285 = prims.div(t281, t284) # t285: "cuda:0 f32[1, 32, 512, 512]" - # t286 = ltorch.to(t285, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t286: "cuda:0 bf16[1, 32, 512, 512]" - # t286 = prims.convert_element_type(t285, dtypes.bfloat16) # t286: "cuda:0 bf16[1, 32, 512, 512]" - # t287 = ltorch.matmul(t286, t223) # t287: "cuda:0 bf16[1, 32, 512, 128]" - # t287 = prims.matmul(t286, t223) # t287: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t288 = ltorch.transpose(t287, 1, 2) # t288: "cuda:0 bf16[1, 512, 32, 128]" - # t288 = prims.transpose(t287, (0, 2, 1, 3)) # t288: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t289 = ltorch.reshape(t288, 1, 512, 4096) # t289: "cuda:0 bf16[1, 512, 4096]" - # t289 = prims.reshape(t288, (1, 512, 4096)) # t289: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t293 = ltorch.linear(t289, t_transformer_h_1_attn_proj_weight, None) # t293: "cuda:0 bf16[1, 512, 4096]" - # t293 = prims.linear(t289, t_transformer_h_1_attn_proj_weight, None) # t293: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t297 = ltorch.add(t293, t187, alpha=None) # t297: "cuda:0 bf16[1, 512, 4096]" - # t294 = prims.convert_element_type(t293, dtypes.float32) # t294: "cuda:0 f32[1, 512, 4096]" - # t295 = prims.convert_element_type(t187, dtypes.float32) # t295: "cuda:0 f32[1, 512, 4096]" - # t296 = prims.add(t294, t295) # t296: "cuda:0 f32[1, 512, 4096]" - # t297 = prims.convert_element_type(t296, dtypes.bfloat16) # t297: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t298 = prims.convert_element_type(t297, dtypes.float32) # t298: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t299 = ltorch.mul(t298, t298) # t299: "cuda:0 f32[1, 512, 4096]" - # t299 = prims.mul(t298, t298) # t299: "cuda:0 f32[1, 512, 4096]" - t303 = ltorch.mean(t299, -1, True, dtype=None) # t303: "cuda:0 f32[1, 512, 1]" - # t301 = prims.sum(t299, (2,)) # t301: "cuda:0 f32[1, 512]" - # t302 = prims.broadcast_in_dim(t301, [1, 512, 1], [0, 1]) # t302: "cuda:0 f32[1, 512, 1]" - # t303 = ltorch.true_divide(t302, 4096) # t303: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t303 = prims.div(t302, 4096.0) # t303: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t305 = ltorch.add(t303, 1e-05, alpha=None) # t305: "cuda:0 f32[1, 512, 1]" - # t305 = prims.add(t303, 1e-05) # t305: "cuda:0 f32[1, 512, 1]" - t306 = ltorch.rsqrt(t305) # t306: "cuda:0 f32[1, 512, 1]" - # t306 = prims.rsqrt(t305) # t306: "cuda:0 f32[1, 512, 1]" - t308 = ltorch.mul(t298, t306) # t308: "cuda:0 f32[1, 512, 4096]" - # t307 = prims.broadcast_in_dim(t306, (1, 512, 4096), (0, 1, 2)) # t307: "cuda:0 f32[1, 512, 4096]" - # t308 = prims.mul(t298, t307) # t308: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t309 = ltorch.to(t308, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t309: "cuda:0 bf16[1, 512, 4096]" - # t309 = prims.convert_element_type(t308, dtypes.bfloat16) # t309: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t319 = ltorch.mul(t309, t_transformer_h_1_norm_2_weight) # t319: "cuda:0 bf16[1, 512, 4096]" - # t315 = prims.broadcast_in_dim(t_transformer_h_1_norm_2_weight, (1, 512, 4096), (2,)) # t315: "cuda:0 bf16[1, 512, 4096]" - # t316 = prims.convert_element_type(t309, dtypes.float32) # t316: "cuda:0 f32[1, 512, 4096]" - # t317 = prims.convert_element_type(t315, dtypes.float32) # t317: "cuda:0 f32[1, 512, 4096]" - # t318 = prims.mul(t316, t317) # t318: "cuda:0 f32[1, 512, 4096]" - # t319 = prims.convert_element_type(t318, dtypes.bfloat16) # t319: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t324 = ltorch.linear(t319, t_transformer_h_1_mlp_fc_1_weight, None) # t324: "cuda:0 bf16[1, 512, 11008]" - # t324 = prims.linear(t319, t_transformer_h_1_mlp_fc_1_weight, None) # t324: "cuda:0 bf16[1, 512, 11008]" - t328 = ltorch.linear(t319, t_transformer_h_1_mlp_fc_2_weight, None) # t328: "cuda:0 bf16[1, 512, 11008]" - # t328 = prims.linear(t319, t_transformer_h_1_mlp_fc_2_weight, None) # t328: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t338 = ltorch.silu(t324, False) # t338: "cuda:0 bf16[1, 512, 11008]" - # t329 = prims.convert_element_type(t324, dtypes.float32) # t329: "cuda:0 f32[1, 512, 11008]" - # t330 = prims.neg(t329) # t330: "cuda:0 f32[1, 512, 11008]" - # t331 = prims.exp(t330) # t331: "cuda:0 f32[1, 512, 11008]" - # t332 = prims.add(1.0, t331) # t332: "cuda:0 f32[1, 512, 11008]" - # t333 = prims.reciprocal(t332) # t333: "cuda:0 f32[1, 512, 11008]" - # t334 = prims.convert_element_type(t333, dtypes.bfloat16) # t334: "cuda:0 bf16[1, 512, 11008]" - # t335 = prims.convert_element_type(t324, dtypes.float32) # t335: "cuda:0 f32[1, 512, 11008]" - # t336 = prims.convert_element_type(t334, dtypes.float32) # t336: "cuda:0 f32[1, 512, 11008]" - # t337 = prims.mul(t335, t336) # t337: "cuda:0 f32[1, 512, 11008]" - # t338 = prims.convert_element_type(t337, dtypes.bfloat16) # t338: "cuda:0 bf16[1, 512, 11008]" - t342 = ltorch.mul(t338, t328) # t342: "cuda:0 bf16[1, 512, 11008]" - # t339 = prims.convert_element_type(t338, dtypes.float32) # t339: "cuda:0 f32[1, 512, 11008]" - # t340 = prims.convert_element_type(t328, dtypes.float32) # t340: "cuda:0 f32[1, 512, 11008]" - # t341 = prims.mul(t339, t340) # t341: "cuda:0 f32[1, 512, 11008]" - # t342 = prims.convert_element_type(t341, dtypes.bfloat16) # t342: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t346 = ltorch.linear(t342, t_transformer_h_1_mlp_proj_weight, None) # t346: "cuda:0 bf16[1, 512, 4096]" - # t346 = prims.linear(t342, t_transformer_h_1_mlp_proj_weight, None) # t346: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t350 = ltorch.add(t346, t297, alpha=None) # t350: "cuda:0 bf16[1, 512, 4096]" - # t347 = prims.convert_element_type(t346, dtypes.float32) # t347: "cuda:0 f32[1, 512, 4096]" - # t348 = prims.convert_element_type(t297, dtypes.float32) # t348: "cuda:0 f32[1, 512, 4096]" - # t349 = prims.add(t347, t348) # t349: "cuda:0 f32[1, 512, 4096]" - # t350 = prims.convert_element_type(t349, dtypes.bfloat16) # t350: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t352 = prims.convert_element_type(t350, dtypes.float32) # t352: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t353 = ltorch.mul(t352, t352) # t353: "cuda:0 f32[1, 512, 4096]" - # t353 = prims.mul(t352, t352) # t353: "cuda:0 f32[1, 512, 4096]" - t357 = ltorch.mean(t353, -1, True, dtype=None) # t357: "cuda:0 f32[1, 512, 1]" - # t355 = prims.sum(t353, (2,)) # t355: "cuda:0 f32[1, 512]" - # t356 = prims.broadcast_in_dim(t355, [1, 512, 1], [0, 1]) # t356: "cuda:0 f32[1, 512, 1]" - # t357 = ltorch.true_divide(t356, 4096) # t357: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t357 = prims.div(t356, 4096.0) # t357: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t359 = ltorch.add(t357, 1e-05, alpha=None) # t359: "cuda:0 f32[1, 512, 1]" - # t359 = prims.add(t357, 1e-05) # t359: "cuda:0 f32[1, 512, 1]" - t360 = ltorch.rsqrt(t359) # t360: "cuda:0 f32[1, 512, 1]" - # t360 = prims.rsqrt(t359) # t360: "cuda:0 f32[1, 512, 1]" - t362 = ltorch.mul(t352, t360) # t362: "cuda:0 f32[1, 512, 4096]" - # t361 = prims.broadcast_in_dim(t360, (1, 512, 4096), (0, 1, 2)) # t361: "cuda:0 f32[1, 512, 4096]" - # t362 = prims.mul(t352, t361) # t362: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t363 = ltorch.to(t362, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t363: "cuda:0 bf16[1, 512, 4096]" - # t363 = prims.convert_element_type(t362, dtypes.bfloat16) # t363: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t373 = ltorch.mul(t363, t_transformer_h_2_norm_1_weight) # t373: "cuda:0 bf16[1, 512, 4096]" - # t369 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_weight, (1, 512, 4096), (2,)) # t369: "cuda:0 bf16[1, 512, 4096]" - # t370 = prims.convert_element_type(t363, dtypes.float32) # t370: "cuda:0 f32[1, 512, 4096]" - # t371 = prims.convert_element_type(t369, dtypes.float32) # t371: "cuda:0 f32[1, 512, 4096]" - # t372 = prims.mul(t370, t371) # t372: "cuda:0 f32[1, 512, 4096]" - # t373 = prims.convert_element_type(t372, dtypes.bfloat16) # t373: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t378 = ltorch.linear(t373, t_transformer_h_2_attn_attn_weight, None) # t378: "cuda:0 bf16[1, 512, 12288]" - # t378 = prims.linear(t373, t_transformer_h_2_attn_attn_weight, None) # t378: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t379 = ltorch.view(t378, 1, 512, 32, 3, 128) # t379: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t379 = ltorch.reshape(t378, (1, 512, 32, 3, 128)) # t379: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t379 = prims.reshape(t378, (1, 512, 32, 3, 128)) # t379: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t380 = ltorch.permute(t379, 0, 2, 3, 1, 4) # t380: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t380 = prims.transpose(t379, (0, 2, 3, 1, 4)) # t380: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t381, t382, t383) = ltorch.split(t380, (1, 1, 1), 2) - # t381 = prims.slice_prim(t380, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t381: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t382 = prims.slice_prim(t380, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t382: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t383 = prims.slice_prim(t380, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t383: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t384 = ltorch.reshape(t381, 1, -1, 512, 128) # t384: "cuda:0 bf16[1, 32, 512, 128]" - # t384 = prims.reshape(t381, (1, 32, 512, 128)) # t384: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t385 = ltorch.reshape(t382, 1, -1, 512, 128) # t385: "cuda:0 bf16[1, 32, 512, 128]" - # t385 = prims.reshape(t382, (1, 32, 512, 128)) # t385: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t386 = ltorch.reshape(t383, 1, -1, 512, 128) # t386: "cuda:0 bf16[1, 32, 512, 128]" - # t386 = prims.reshape(t383, (1, 32, 512, 128)) # t386: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t387 = ltorch.getitem(t384, (..., slice(None, 128, None))) # t387: "cuda:0 bf16[1, 32, 512, 128]" - # t387 = prims.slice_prim(t384, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t387: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t388 = ltorch.getitem(t387, (..., slice(None, 64, None))) # t388: "cuda:0 bf16[1, 32, 512, 64]" - # t388 = prims.slice_prim(t387, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t388: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t389 = ltorch.getitem(t387, (..., slice(64, None, None))) # t389: "cuda:0 bf16[1, 32, 512, 64]" - # t389 = prims.slice_prim(t387, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t389: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t392 = ltorch.neg(t389) # t392: "cuda:0 bf16[1, 32, 512, 64]" - # t390 = prims.convert_element_type(t389, dtypes.float32) # t390: "cuda:0 f32[1, 32, 512, 64]" - # t391 = prims.neg(t390) # t391: "cuda:0 f32[1, 32, 512, 64]" - # t392 = prims.convert_element_type(t391, dtypes.bfloat16) # t392: "cuda:0 bf16[1, 32, 512, 64]" - t393 = ltorch.cat((t392, t388), -1) # t393: "cuda:0 bf16[1, 32, 512, 128]" - # t393 = prims.cat((t392, t388), -1) # t393: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t396 = ltorch.mul(t387, cos) # t396: "cuda:0 f32[1, 32, 512, 128]" - # t394 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t394: "cuda:0 f32[1, 32, 512, 128]" - # t395 = prims.convert_element_type(t387, dtypes.float32) # t395: "cuda:0 f32[1, 32, 512, 128]" - # t396 = prims.mul(t395, t394) # t396: "cuda:0 f32[1, 32, 512, 128]" - t399 = ltorch.mul(t393, sin) # t399: "cuda:0 f32[1, 32, 512, 128]" - # t397 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t397: "cuda:0 f32[1, 32, 512, 128]" - # t398 = prims.convert_element_type(t393, dtypes.float32) # t398: "cuda:0 f32[1, 32, 512, 128]" - # t399 = prims.mul(t398, t397) # t399: "cuda:0 f32[1, 32, 512, 128]" - t400 = ltorch.add(t396, t399, alpha=None) # t400: "cuda:0 f32[1, 32, 512, 128]" - # t400 = prims.add(t396, t399) # t400: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t401 = ltorch.to(t400, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t401: "cuda:0 bf16[1, 32, 512, 128]" - # t401 = prims.convert_element_type(t400, dtypes.bfloat16) # t401: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t402 = ltorch.getitem(t385, (..., slice(None, 128, None))) # t402: "cuda:0 bf16[1, 32, 512, 128]" - # t402 = prims.slice_prim(t385, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t402: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t403 = ltorch.getitem(t402, (..., slice(None, 64, None))) # t403: "cuda:0 bf16[1, 32, 512, 64]" - # t403 = prims.slice_prim(t402, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t403: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t404 = ltorch.getitem(t402, (..., slice(64, None, None))) # t404: "cuda:0 bf16[1, 32, 512, 64]" - # t404 = prims.slice_prim(t402, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t404: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t407 = ltorch.neg(t404) # t407: "cuda:0 bf16[1, 32, 512, 64]" - # t405 = prims.convert_element_type(t404, dtypes.float32) # t405: "cuda:0 f32[1, 32, 512, 64]" - # t406 = prims.neg(t405) # t406: "cuda:0 f32[1, 32, 512, 64]" - # t407 = prims.convert_element_type(t406, dtypes.bfloat16) # t407: "cuda:0 bf16[1, 32, 512, 64]" - t408 = ltorch.cat((t407, t403), -1) # t408: "cuda:0 bf16[1, 32, 512, 128]" - # t408 = prims.cat((t407, t403), -1) # t408: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t411 = ltorch.mul(t402, cos) # t411: "cuda:0 f32[1, 32, 512, 128]" - # t409 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t409: "cuda:0 f32[1, 32, 512, 128]" - # t410 = prims.convert_element_type(t402, dtypes.float32) # t410: "cuda:0 f32[1, 32, 512, 128]" - # t411 = prims.mul(t410, t409) # t411: "cuda:0 f32[1, 32, 512, 128]" - t414 = ltorch.mul(t408, sin) # t414: "cuda:0 f32[1, 32, 512, 128]" - # t412 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t412: "cuda:0 f32[1, 32, 512, 128]" - # t413 = prims.convert_element_type(t408, dtypes.float32) # t413: "cuda:0 f32[1, 32, 512, 128]" - # t414 = prims.mul(t413, t412) # t414: "cuda:0 f32[1, 32, 512, 128]" - t415 = ltorch.add(t411, t414, alpha=None) # t415: "cuda:0 f32[1, 32, 512, 128]" - # t415 = prims.add(t411, t414) # t415: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t416 = ltorch.to(t415, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t416: "cuda:0 bf16[1, 32, 512, 128]" - # t416 = prims.convert_element_type(t415, dtypes.bfloat16) # t416: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t417 = ltorch.getitem(t384, (..., slice(128, None, None))) # t417: "cuda:0 bf16[1, 32, 512, 0]" - # t417 = prims.slice_prim(t384, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t417: "cuda:0 bf16[1, 32, 512, 0]" - t418 = ltorch.cat((t401, t417), -1) # t418: "cuda:0 bf16[1, 32, 512, 128]" - # t418 = prims.cat((t401, t417), -1) # t418: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t419 = ltorch.getitem(t385, (..., slice(128, None, None))) # t419: "cuda:0 bf16[1, 32, 512, 0]" - # t419 = prims.slice_prim(t385, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t419: "cuda:0 bf16[1, 32, 512, 0]" - t420 = ltorch.cat((t416, t419), -1) # t420: "cuda:0 bf16[1, 32, 512, 128]" - # t420 = prims.cat((t416, t419), -1) # t420: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t450 = ltorch.scaled_dot_product_attention(t418, t420, t386, None, 0.0, True, scale=0.08838834764831843) # t450: "cuda:0 bf16[1, 32, 512, 128]" - # t423 = ltorch.mul(t418, 0.29730177875068026) # t423: "cuda:0 bf16[1, 32, 512, 128]" - # t421 = prims.convert_element_type(t418, dtypes.float32) # t421: "cuda:0 f32[1, 32, 512, 128]" - # t422 = prims.mul(t421, 0.29730177875068026) # t422: "cuda:0 f32[1, 32, 512, 128]" - # t423 = prims.convert_element_type(t422, dtypes.bfloat16) # t423: "cuda:0 bf16[1, 32, 512, 128]" - # t424 = ltorch.transpose(t420, -2, -1) # t424: "cuda:0 bf16[1, 32, 128, 512]" - # t424 = prims.transpose(t420, (0, 1, 3, 2)) # t424: "cuda:0 bf16[1, 32, 128, 512]" - # t427 = ltorch.mul(t424, 0.29730177875068026) # t427: "cuda:0 bf16[1, 32, 128, 512]" - # t425 = prims.convert_element_type(t424, dtypes.float32) # t425: "cuda:0 f32[1, 32, 128, 512]" - # t426 = prims.mul(t425, 0.29730177875068026) # t426: "cuda:0 f32[1, 32, 128, 512]" - # t427 = prims.convert_element_type(t426, dtypes.bfloat16) # t427: "cuda:0 bf16[1, 32, 128, 512]" - # t428 = ltorch.matmul(t423, t427) # t428: "cuda:0 bf16[1, 32, 512, 512]" - # t428 = prims.matmul(t423, t427) # t428: "cuda:0 bf16[1, 32, 512, 512]" - # t438 = ltorch.tril(t428, 0, fill_value=-float('inf')) # t438: "cuda:0 bf16[1, 32, 512, 512]" - # t429 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t429: "cuda:0 i64[512]" - # t429 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t429: "cuda:0 i64[512]" - # t430 = ltorch.unsqueeze(t429, -1) # t430: "cuda:0 i64[512, 1]" - # t430 = prims.broadcast_in_dim(t429, [512, 1], [0]) # t430: "cuda:0 i64[512, 1]" - # t431 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t431: "cuda:0 i64[512]" - # t431 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t431: "cuda:0 i64[512]" - # t432 = ltorch.unsqueeze(t431, -2) # t432: "cuda:0 i64[1, 512]" - # t432 = prims.broadcast_in_dim(t431, [1, 512], [1]) # t432: "cuda:0 i64[1, 512]" - # t433 = ltorch.add(t430, 0, alpha=None) # t433: "cuda:0 i64[512, 1]" - # t433 = prims.add(t430, 0) # t433: "cuda:0 i64[512, 1]" - # t436 = ltorch.ge(t433, t432) # t436: "cuda:0 b8[512, 512]" - # t434 = prims.broadcast_in_dim(t433, (512, 512), (0, 1)) # t434: "cuda:0 i64[512, 512]" - # t435 = prims.broadcast_in_dim(t432, (512, 512), (0, 1)) # t435: "cuda:0 i64[512, 512]" - # t436 = prims.ge(t434, t435) # t436: "cuda:0 b8[512, 512]" - # t438 = ltorch.where(t436, t428, -float('inf')) # t438: "cuda:0 bf16[1, 32, 512, 512]" - # t437 = prims.broadcast_in_dim(t436, (1, 32, 512, 512), (2, 3)) # t437: "cuda:0 b8[1, 32, 512, 512]" - # t438 = prims.where(t437, t428, -float('inf')) # t438: "cuda:0 bf16[1, 32, 512, 512]" - # t449 = ltorch._softmax(t438, -1, dtype=None) # t449: "cuda:0 bf16[1, 32, 512, 512]" - # t439 = ltorch.to(t438, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t439: "cuda:0 f32[1, 32, 512, 512]" - # t439 = prims.convert_element_type(t438, dtypes.float32) # t439: "cuda:0 f32[1, 32, 512, 512]" - # t441 = ltorch.amax(t439, -1, True) # t441: "cuda:0 f32[1, 32, 512, 1]" - # t440 = prims.amax(t439, (3,)) # t440: "cuda:0 f32[1, 32, 512]" - # t441 = prims.broadcast_in_dim(t440, [1, 32, 512, 1], [0, 1, 2]) # t441: "cuda:0 f32[1, 32, 512, 1]" - # t443 = ltorch.sub(t439, t441, alpha=None) # t443: "cuda:0 f32[1, 32, 512, 512]" - # t442 = prims.broadcast_in_dim(t441, (1, 32, 512, 512), (0, 1, 2, 3)) # t442: "cuda:0 f32[1, 32, 512, 512]" - # t443 = prims.sub(t439, t442) # t443: "cuda:0 f32[1, 32, 512, 512]" - # t444 = ltorch.exp(t443) # t444: "cuda:0 f32[1, 32, 512, 512]" - # t444 = prims.exp(t443) # t444: "cuda:0 f32[1, 32, 512, 512]" - # t446 = ltorch.sum(t444, -1, True, dtype=None) # t446: "cuda:0 f32[1, 32, 512, 1]" - # t445 = prims.sum(t444, (3,)) # t445: "cuda:0 f32[1, 32, 512]" - # t446 = prims.broadcast_in_dim(t445, [1, 32, 512, 1], [0, 1, 2]) # t446: "cuda:0 f32[1, 32, 512, 1]" - # t448 = ltorch.true_divide(t444, t446) # t448: "cuda:0 f32[1, 32, 512, 512]" - # t447 = prims.broadcast_in_dim(t446, (1, 32, 512, 512), (0, 1, 2, 3)) # t447: "cuda:0 f32[1, 32, 512, 512]" - # t448 = prims.div(t444, t447) # t448: "cuda:0 f32[1, 32, 512, 512]" - # t449 = ltorch.to(t448, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t449: "cuda:0 bf16[1, 32, 512, 512]" - # t449 = prims.convert_element_type(t448, dtypes.bfloat16) # t449: "cuda:0 bf16[1, 32, 512, 512]" - # t450 = ltorch.matmul(t449, t386) # t450: "cuda:0 bf16[1, 32, 512, 128]" - # t450 = prims.matmul(t449, t386) # t450: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t451 = ltorch.transpose(t450, 1, 2) # t451: "cuda:0 bf16[1, 512, 32, 128]" - # t451 = prims.transpose(t450, (0, 2, 1, 3)) # t451: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t452 = ltorch.reshape(t451, 1, 512, 4096) # t452: "cuda:0 bf16[1, 512, 4096]" - # t452 = prims.reshape(t451, (1, 512, 4096)) # t452: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t456 = ltorch.linear(t452, t_transformer_h_2_attn_proj_weight, None) # t456: "cuda:0 bf16[1, 512, 4096]" - # t456 = prims.linear(t452, t_transformer_h_2_attn_proj_weight, None) # t456: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t460 = ltorch.add(t456, t350, alpha=None) # t460: "cuda:0 bf16[1, 512, 4096]" - # t457 = prims.convert_element_type(t456, dtypes.float32) # t457: "cuda:0 f32[1, 512, 4096]" - # t458 = prims.convert_element_type(t350, dtypes.float32) # t458: "cuda:0 f32[1, 512, 4096]" - # t459 = prims.add(t457, t458) # t459: "cuda:0 f32[1, 512, 4096]" - # t460 = prims.convert_element_type(t459, dtypes.bfloat16) # t460: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t461 = prims.convert_element_type(t460, dtypes.float32) # t461: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t462 = ltorch.mul(t461, t461) # t462: "cuda:0 f32[1, 512, 4096]" - # t462 = prims.mul(t461, t461) # t462: "cuda:0 f32[1, 512, 4096]" - t466 = ltorch.mean(t462, -1, True, dtype=None) # t466: "cuda:0 f32[1, 512, 1]" - # t464 = prims.sum(t462, (2,)) # t464: "cuda:0 f32[1, 512]" - # t465 = prims.broadcast_in_dim(t464, [1, 512, 1], [0, 1]) # t465: "cuda:0 f32[1, 512, 1]" - # t466 = ltorch.true_divide(t465, 4096) # t466: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t466 = prims.div(t465, 4096.0) # t466: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t468 = ltorch.add(t466, 1e-05, alpha=None) # t468: "cuda:0 f32[1, 512, 1]" - # t468 = prims.add(t466, 1e-05) # t468: "cuda:0 f32[1, 512, 1]" - t469 = ltorch.rsqrt(t468) # t469: "cuda:0 f32[1, 512, 1]" - # t469 = prims.rsqrt(t468) # t469: "cuda:0 f32[1, 512, 1]" - t471 = ltorch.mul(t461, t469) # t471: "cuda:0 f32[1, 512, 4096]" - # t470 = prims.broadcast_in_dim(t469, (1, 512, 4096), (0, 1, 2)) # t470: "cuda:0 f32[1, 512, 4096]" - # t471 = prims.mul(t461, t470) # t471: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t472 = ltorch.to(t471, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t472: "cuda:0 bf16[1, 512, 4096]" - # t472 = prims.convert_element_type(t471, dtypes.bfloat16) # t472: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t482 = ltorch.mul(t472, t_transformer_h_2_norm_2_weight) # t482: "cuda:0 bf16[1, 512, 4096]" - # t478 = prims.broadcast_in_dim(t_transformer_h_2_norm_2_weight, (1, 512, 4096), (2,)) # t478: "cuda:0 bf16[1, 512, 4096]" - # t479 = prims.convert_element_type(t472, dtypes.float32) # t479: "cuda:0 f32[1, 512, 4096]" - # t480 = prims.convert_element_type(t478, dtypes.float32) # t480: "cuda:0 f32[1, 512, 4096]" - # t481 = prims.mul(t479, t480) # t481: "cuda:0 f32[1, 512, 4096]" - # t482 = prims.convert_element_type(t481, dtypes.bfloat16) # t482: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t487 = ltorch.linear(t482, t_transformer_h_2_mlp_fc_1_weight, None) # t487: "cuda:0 bf16[1, 512, 11008]" - # t487 = prims.linear(t482, t_transformer_h_2_mlp_fc_1_weight, None) # t487: "cuda:0 bf16[1, 512, 11008]" - t491 = ltorch.linear(t482, t_transformer_h_2_mlp_fc_2_weight, None) # t491: "cuda:0 bf16[1, 512, 11008]" - # t491 = prims.linear(t482, t_transformer_h_2_mlp_fc_2_weight, None) # t491: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t501 = ltorch.silu(t487, False) # t501: "cuda:0 bf16[1, 512, 11008]" - # t492 = prims.convert_element_type(t487, dtypes.float32) # t492: "cuda:0 f32[1, 512, 11008]" - # t493 = prims.neg(t492) # t493: "cuda:0 f32[1, 512, 11008]" - # t494 = prims.exp(t493) # t494: "cuda:0 f32[1, 512, 11008]" - # t495 = prims.add(1.0, t494) # t495: "cuda:0 f32[1, 512, 11008]" - # t496 = prims.reciprocal(t495) # t496: "cuda:0 f32[1, 512, 11008]" - # t497 = prims.convert_element_type(t496, dtypes.bfloat16) # t497: "cuda:0 bf16[1, 512, 11008]" - # t498 = prims.convert_element_type(t487, dtypes.float32) # t498: "cuda:0 f32[1, 512, 11008]" - # t499 = prims.convert_element_type(t497, dtypes.float32) # t499: "cuda:0 f32[1, 512, 11008]" - # t500 = prims.mul(t498, t499) # t500: "cuda:0 f32[1, 512, 11008]" - # t501 = prims.convert_element_type(t500, dtypes.bfloat16) # t501: "cuda:0 bf16[1, 512, 11008]" - t505 = ltorch.mul(t501, t491) # t505: "cuda:0 bf16[1, 512, 11008]" - # t502 = prims.convert_element_type(t501, dtypes.float32) # t502: "cuda:0 f32[1, 512, 11008]" - # t503 = prims.convert_element_type(t491, dtypes.float32) # t503: "cuda:0 f32[1, 512, 11008]" - # t504 = prims.mul(t502, t503) # t504: "cuda:0 f32[1, 512, 11008]" - # t505 = prims.convert_element_type(t504, dtypes.bfloat16) # t505: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t509 = ltorch.linear(t505, t_transformer_h_2_mlp_proj_weight, None) # t509: "cuda:0 bf16[1, 512, 4096]" - # t509 = prims.linear(t505, t_transformer_h_2_mlp_proj_weight, None) # t509: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t513 = ltorch.add(t509, t460, alpha=None) # t513: "cuda:0 bf16[1, 512, 4096]" - # t510 = prims.convert_element_type(t509, dtypes.float32) # t510: "cuda:0 f32[1, 512, 4096]" - # t511 = prims.convert_element_type(t460, dtypes.float32) # t511: "cuda:0 f32[1, 512, 4096]" - # t512 = prims.add(t510, t511) # t512: "cuda:0 f32[1, 512, 4096]" - # t513 = prims.convert_element_type(t512, dtypes.bfloat16) # t513: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t515 = prims.convert_element_type(t513, dtypes.float32) # t515: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t516 = ltorch.mul(t515, t515) # t516: "cuda:0 f32[1, 512, 4096]" - # t516 = prims.mul(t515, t515) # t516: "cuda:0 f32[1, 512, 4096]" - t520 = ltorch.mean(t516, -1, True, dtype=None) # t520: "cuda:0 f32[1, 512, 1]" - # t518 = prims.sum(t516, (2,)) # t518: "cuda:0 f32[1, 512]" - # t519 = prims.broadcast_in_dim(t518, [1, 512, 1], [0, 1]) # t519: "cuda:0 f32[1, 512, 1]" - # t520 = ltorch.true_divide(t519, 4096) # t520: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t520 = prims.div(t519, 4096.0) # t520: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t522 = ltorch.add(t520, 1e-05, alpha=None) # t522: "cuda:0 f32[1, 512, 1]" - # t522 = prims.add(t520, 1e-05) # t522: "cuda:0 f32[1, 512, 1]" - t523 = ltorch.rsqrt(t522) # t523: "cuda:0 f32[1, 512, 1]" - # t523 = prims.rsqrt(t522) # t523: "cuda:0 f32[1, 512, 1]" - t525 = ltorch.mul(t515, t523) # t525: "cuda:0 f32[1, 512, 4096]" - # t524 = prims.broadcast_in_dim(t523, (1, 512, 4096), (0, 1, 2)) # t524: "cuda:0 f32[1, 512, 4096]" - # t525 = prims.mul(t515, t524) # t525: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t526 = ltorch.to(t525, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t526: "cuda:0 bf16[1, 512, 4096]" - # t526 = prims.convert_element_type(t525, dtypes.bfloat16) # t526: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t536 = ltorch.mul(t526, t_transformer_h_3_norm_1_weight) # t536: "cuda:0 bf16[1, 512, 4096]" - # t532 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_weight, (1, 512, 4096), (2,)) # t532: "cuda:0 bf16[1, 512, 4096]" - # t533 = prims.convert_element_type(t526, dtypes.float32) # t533: "cuda:0 f32[1, 512, 4096]" - # t534 = prims.convert_element_type(t532, dtypes.float32) # t534: "cuda:0 f32[1, 512, 4096]" - # t535 = prims.mul(t533, t534) # t535: "cuda:0 f32[1, 512, 4096]" - # t536 = prims.convert_element_type(t535, dtypes.bfloat16) # t536: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t541 = ltorch.linear(t536, t_transformer_h_3_attn_attn_weight, None) # t541: "cuda:0 bf16[1, 512, 12288]" - # t541 = prims.linear(t536, t_transformer_h_3_attn_attn_weight, None) # t541: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t542 = ltorch.view(t541, 1, 512, 32, 3, 128) # t542: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t542 = ltorch.reshape(t541, (1, 512, 32, 3, 128)) # t542: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t542 = prims.reshape(t541, (1, 512, 32, 3, 128)) # t542: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t543 = ltorch.permute(t542, 0, 2, 3, 1, 4) # t543: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t543 = prims.transpose(t542, (0, 2, 3, 1, 4)) # t543: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t544, t545, t546) = ltorch.split(t543, (1, 1, 1), 2) - # t544 = prims.slice_prim(t543, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t544: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t545 = prims.slice_prim(t543, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t545: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t546 = prims.slice_prim(t543, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t546: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t547 = ltorch.reshape(t544, 1, -1, 512, 128) # t547: "cuda:0 bf16[1, 32, 512, 128]" - # t547 = prims.reshape(t544, (1, 32, 512, 128)) # t547: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t548 = ltorch.reshape(t545, 1, -1, 512, 128) # t548: "cuda:0 bf16[1, 32, 512, 128]" - # t548 = prims.reshape(t545, (1, 32, 512, 128)) # t548: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t549 = ltorch.reshape(t546, 1, -1, 512, 128) # t549: "cuda:0 bf16[1, 32, 512, 128]" - # t549 = prims.reshape(t546, (1, 32, 512, 128)) # t549: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t550 = ltorch.getitem(t547, (..., slice(None, 128, None))) # t550: "cuda:0 bf16[1, 32, 512, 128]" - # t550 = prims.slice_prim(t547, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t550: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t551 = ltorch.getitem(t550, (..., slice(None, 64, None))) # t551: "cuda:0 bf16[1, 32, 512, 64]" - # t551 = prims.slice_prim(t550, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t551: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t552 = ltorch.getitem(t550, (..., slice(64, None, None))) # t552: "cuda:0 bf16[1, 32, 512, 64]" - # t552 = prims.slice_prim(t550, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t552: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t555 = ltorch.neg(t552) # t555: "cuda:0 bf16[1, 32, 512, 64]" - # t553 = prims.convert_element_type(t552, dtypes.float32) # t553: "cuda:0 f32[1, 32, 512, 64]" - # t554 = prims.neg(t553) # t554: "cuda:0 f32[1, 32, 512, 64]" - # t555 = prims.convert_element_type(t554, dtypes.bfloat16) # t555: "cuda:0 bf16[1, 32, 512, 64]" - t556 = ltorch.cat((t555, t551), -1) # t556: "cuda:0 bf16[1, 32, 512, 128]" - # t556 = prims.cat((t555, t551), -1) # t556: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t559 = ltorch.mul(t550, cos) # t559: "cuda:0 f32[1, 32, 512, 128]" - # t557 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t557: "cuda:0 f32[1, 32, 512, 128]" - # t558 = prims.convert_element_type(t550, dtypes.float32) # t558: "cuda:0 f32[1, 32, 512, 128]" - # t559 = prims.mul(t558, t557) # t559: "cuda:0 f32[1, 32, 512, 128]" - t562 = ltorch.mul(t556, sin) # t562: "cuda:0 f32[1, 32, 512, 128]" - # t560 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t560: "cuda:0 f32[1, 32, 512, 128]" - # t561 = prims.convert_element_type(t556, dtypes.float32) # t561: "cuda:0 f32[1, 32, 512, 128]" - # t562 = prims.mul(t561, t560) # t562: "cuda:0 f32[1, 32, 512, 128]" - t563 = ltorch.add(t559, t562, alpha=None) # t563: "cuda:0 f32[1, 32, 512, 128]" - # t563 = prims.add(t559, t562) # t563: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t564 = ltorch.to(t563, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t564: "cuda:0 bf16[1, 32, 512, 128]" - # t564 = prims.convert_element_type(t563, dtypes.bfloat16) # t564: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t565 = ltorch.getitem(t548, (..., slice(None, 128, None))) # t565: "cuda:0 bf16[1, 32, 512, 128]" - # t565 = prims.slice_prim(t548, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t565: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t566 = ltorch.getitem(t565, (..., slice(None, 64, None))) # t566: "cuda:0 bf16[1, 32, 512, 64]" - # t566 = prims.slice_prim(t565, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t566: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t567 = ltorch.getitem(t565, (..., slice(64, None, None))) # t567: "cuda:0 bf16[1, 32, 512, 64]" - # t567 = prims.slice_prim(t565, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t567: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t570 = ltorch.neg(t567) # t570: "cuda:0 bf16[1, 32, 512, 64]" - # t568 = prims.convert_element_type(t567, dtypes.float32) # t568: "cuda:0 f32[1, 32, 512, 64]" - # t569 = prims.neg(t568) # t569: "cuda:0 f32[1, 32, 512, 64]" - # t570 = prims.convert_element_type(t569, dtypes.bfloat16) # t570: "cuda:0 bf16[1, 32, 512, 64]" - t571 = ltorch.cat((t570, t566), -1) # t571: "cuda:0 bf16[1, 32, 512, 128]" - # t571 = prims.cat((t570, t566), -1) # t571: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t574 = ltorch.mul(t565, cos) # t574: "cuda:0 f32[1, 32, 512, 128]" - # t572 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t572: "cuda:0 f32[1, 32, 512, 128]" - # t573 = prims.convert_element_type(t565, dtypes.float32) # t573: "cuda:0 f32[1, 32, 512, 128]" - # t574 = prims.mul(t573, t572) # t574: "cuda:0 f32[1, 32, 512, 128]" - t577 = ltorch.mul(t571, sin) # t577: "cuda:0 f32[1, 32, 512, 128]" - # t575 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t575: "cuda:0 f32[1, 32, 512, 128]" - # t576 = prims.convert_element_type(t571, dtypes.float32) # t576: "cuda:0 f32[1, 32, 512, 128]" - # t577 = prims.mul(t576, t575) # t577: "cuda:0 f32[1, 32, 512, 128]" - t578 = ltorch.add(t574, t577, alpha=None) # t578: "cuda:0 f32[1, 32, 512, 128]" - # t578 = prims.add(t574, t577) # t578: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t579 = ltorch.to(t578, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t579: "cuda:0 bf16[1, 32, 512, 128]" - # t579 = prims.convert_element_type(t578, dtypes.bfloat16) # t579: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t580 = ltorch.getitem(t547, (..., slice(128, None, None))) # t580: "cuda:0 bf16[1, 32, 512, 0]" - # t580 = prims.slice_prim(t547, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t580: "cuda:0 bf16[1, 32, 512, 0]" - t581 = ltorch.cat((t564, t580), -1) # t581: "cuda:0 bf16[1, 32, 512, 128]" - # t581 = prims.cat((t564, t580), -1) # t581: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t582 = ltorch.getitem(t548, (..., slice(128, None, None))) # t582: "cuda:0 bf16[1, 32, 512, 0]" - # t582 = prims.slice_prim(t548, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t582: "cuda:0 bf16[1, 32, 512, 0]" - t583 = ltorch.cat((t579, t582), -1) # t583: "cuda:0 bf16[1, 32, 512, 128]" - # t583 = prims.cat((t579, t582), -1) # t583: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t613 = ltorch.scaled_dot_product_attention(t581, t583, t549, None, 0.0, True, scale=0.08838834764831843) # t613: "cuda:0 bf16[1, 32, 512, 128]" - # t586 = ltorch.mul(t581, 0.29730177875068026) # t586: "cuda:0 bf16[1, 32, 512, 128]" - # t584 = prims.convert_element_type(t581, dtypes.float32) # t584: "cuda:0 f32[1, 32, 512, 128]" - # t585 = prims.mul(t584, 0.29730177875068026) # t585: "cuda:0 f32[1, 32, 512, 128]" - # t586 = prims.convert_element_type(t585, dtypes.bfloat16) # t586: "cuda:0 bf16[1, 32, 512, 128]" - # t587 = ltorch.transpose(t583, -2, -1) # t587: "cuda:0 bf16[1, 32, 128, 512]" - # t587 = prims.transpose(t583, (0, 1, 3, 2)) # t587: "cuda:0 bf16[1, 32, 128, 512]" - # t590 = ltorch.mul(t587, 0.29730177875068026) # t590: "cuda:0 bf16[1, 32, 128, 512]" - # t588 = prims.convert_element_type(t587, dtypes.float32) # t588: "cuda:0 f32[1, 32, 128, 512]" - # t589 = prims.mul(t588, 0.29730177875068026) # t589: "cuda:0 f32[1, 32, 128, 512]" - # t590 = prims.convert_element_type(t589, dtypes.bfloat16) # t590: "cuda:0 bf16[1, 32, 128, 512]" - # t591 = ltorch.matmul(t586, t590) # t591: "cuda:0 bf16[1, 32, 512, 512]" - # t591 = prims.matmul(t586, t590) # t591: "cuda:0 bf16[1, 32, 512, 512]" - # t601 = ltorch.tril(t591, 0, fill_value=-float('inf')) # t601: "cuda:0 bf16[1, 32, 512, 512]" - # t592 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t592: "cuda:0 i64[512]" - # t592 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t592: "cuda:0 i64[512]" - # t593 = ltorch.unsqueeze(t592, -1) # t593: "cuda:0 i64[512, 1]" - # t593 = prims.broadcast_in_dim(t592, [512, 1], [0]) # t593: "cuda:0 i64[512, 1]" - # t594 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t594: "cuda:0 i64[512]" - # t594 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t594: "cuda:0 i64[512]" - # t595 = ltorch.unsqueeze(t594, -2) # t595: "cuda:0 i64[1, 512]" - # t595 = prims.broadcast_in_dim(t594, [1, 512], [1]) # t595: "cuda:0 i64[1, 512]" - # t596 = ltorch.add(t593, 0, alpha=None) # t596: "cuda:0 i64[512, 1]" - # t596 = prims.add(t593, 0) # t596: "cuda:0 i64[512, 1]" - # t599 = ltorch.ge(t596, t595) # t599: "cuda:0 b8[512, 512]" - # t597 = prims.broadcast_in_dim(t596, (512, 512), (0, 1)) # t597: "cuda:0 i64[512, 512]" - # t598 = prims.broadcast_in_dim(t595, (512, 512), (0, 1)) # t598: "cuda:0 i64[512, 512]" - # t599 = prims.ge(t597, t598) # t599: "cuda:0 b8[512, 512]" - # t601 = ltorch.where(t599, t591, -float('inf')) # t601: "cuda:0 bf16[1, 32, 512, 512]" - # t600 = prims.broadcast_in_dim(t599, (1, 32, 512, 512), (2, 3)) # t600: "cuda:0 b8[1, 32, 512, 512]" - # t601 = prims.where(t600, t591, -float('inf')) # t601: "cuda:0 bf16[1, 32, 512, 512]" - # t612 = ltorch._softmax(t601, -1, dtype=None) # t612: "cuda:0 bf16[1, 32, 512, 512]" - # t602 = ltorch.to(t601, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t602: "cuda:0 f32[1, 32, 512, 512]" - # t602 = prims.convert_element_type(t601, dtypes.float32) # t602: "cuda:0 f32[1, 32, 512, 512]" - # t604 = ltorch.amax(t602, -1, True) # t604: "cuda:0 f32[1, 32, 512, 1]" - # t603 = prims.amax(t602, (3,)) # t603: "cuda:0 f32[1, 32, 512]" - # t604 = prims.broadcast_in_dim(t603, [1, 32, 512, 1], [0, 1, 2]) # t604: "cuda:0 f32[1, 32, 512, 1]" - # t606 = ltorch.sub(t602, t604, alpha=None) # t606: "cuda:0 f32[1, 32, 512, 512]" - # t605 = prims.broadcast_in_dim(t604, (1, 32, 512, 512), (0, 1, 2, 3)) # t605: "cuda:0 f32[1, 32, 512, 512]" - # t606 = prims.sub(t602, t605) # t606: "cuda:0 f32[1, 32, 512, 512]" - # t607 = ltorch.exp(t606) # t607: "cuda:0 f32[1, 32, 512, 512]" - # t607 = prims.exp(t606) # t607: "cuda:0 f32[1, 32, 512, 512]" - # t609 = ltorch.sum(t607, -1, True, dtype=None) # t609: "cuda:0 f32[1, 32, 512, 1]" - # t608 = prims.sum(t607, (3,)) # t608: "cuda:0 f32[1, 32, 512]" - # t609 = prims.broadcast_in_dim(t608, [1, 32, 512, 1], [0, 1, 2]) # t609: "cuda:0 f32[1, 32, 512, 1]" - # t611 = ltorch.true_divide(t607, t609) # t611: "cuda:0 f32[1, 32, 512, 512]" - # t610 = prims.broadcast_in_dim(t609, (1, 32, 512, 512), (0, 1, 2, 3)) # t610: "cuda:0 f32[1, 32, 512, 512]" - # t611 = prims.div(t607, t610) # t611: "cuda:0 f32[1, 32, 512, 512]" - # t612 = ltorch.to(t611, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t612: "cuda:0 bf16[1, 32, 512, 512]" - # t612 = prims.convert_element_type(t611, dtypes.bfloat16) # t612: "cuda:0 bf16[1, 32, 512, 512]" - # t613 = ltorch.matmul(t612, t549) # t613: "cuda:0 bf16[1, 32, 512, 128]" - # t613 = prims.matmul(t612, t549) # t613: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t614 = ltorch.transpose(t613, 1, 2) # t614: "cuda:0 bf16[1, 512, 32, 128]" - # t614 = prims.transpose(t613, (0, 2, 1, 3)) # t614: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t615 = ltorch.reshape(t614, 1, 512, 4096) # t615: "cuda:0 bf16[1, 512, 4096]" - # t615 = prims.reshape(t614, (1, 512, 4096)) # t615: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t619 = ltorch.linear(t615, t_transformer_h_3_attn_proj_weight, None) # t619: "cuda:0 bf16[1, 512, 4096]" - # t619 = prims.linear(t615, t_transformer_h_3_attn_proj_weight, None) # t619: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t623 = ltorch.add(t619, t513, alpha=None) # t623: "cuda:0 bf16[1, 512, 4096]" - # t620 = prims.convert_element_type(t619, dtypes.float32) # t620: "cuda:0 f32[1, 512, 4096]" - # t621 = prims.convert_element_type(t513, dtypes.float32) # t621: "cuda:0 f32[1, 512, 4096]" - # t622 = prims.add(t620, t621) # t622: "cuda:0 f32[1, 512, 4096]" - # t623 = prims.convert_element_type(t622, dtypes.bfloat16) # t623: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t624 = prims.convert_element_type(t623, dtypes.float32) # t624: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t625 = ltorch.mul(t624, t624) # t625: "cuda:0 f32[1, 512, 4096]" - # t625 = prims.mul(t624, t624) # t625: "cuda:0 f32[1, 512, 4096]" - t629 = ltorch.mean(t625, -1, True, dtype=None) # t629: "cuda:0 f32[1, 512, 1]" - # t627 = prims.sum(t625, (2,)) # t627: "cuda:0 f32[1, 512]" - # t628 = prims.broadcast_in_dim(t627, [1, 512, 1], [0, 1]) # t628: "cuda:0 f32[1, 512, 1]" - # t629 = ltorch.true_divide(t628, 4096) # t629: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t629 = prims.div(t628, 4096.0) # t629: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t631 = ltorch.add(t629, 1e-05, alpha=None) # t631: "cuda:0 f32[1, 512, 1]" - # t631 = prims.add(t629, 1e-05) # t631: "cuda:0 f32[1, 512, 1]" - t632 = ltorch.rsqrt(t631) # t632: "cuda:0 f32[1, 512, 1]" - # t632 = prims.rsqrt(t631) # t632: "cuda:0 f32[1, 512, 1]" - t634 = ltorch.mul(t624, t632) # t634: "cuda:0 f32[1, 512, 4096]" - # t633 = prims.broadcast_in_dim(t632, (1, 512, 4096), (0, 1, 2)) # t633: "cuda:0 f32[1, 512, 4096]" - # t634 = prims.mul(t624, t633) # t634: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t635 = ltorch.to(t634, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t635: "cuda:0 bf16[1, 512, 4096]" - # t635 = prims.convert_element_type(t634, dtypes.bfloat16) # t635: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t645 = ltorch.mul(t635, t_transformer_h_3_norm_2_weight) # t645: "cuda:0 bf16[1, 512, 4096]" - # t641 = prims.broadcast_in_dim(t_transformer_h_3_norm_2_weight, (1, 512, 4096), (2,)) # t641: "cuda:0 bf16[1, 512, 4096]" - # t642 = prims.convert_element_type(t635, dtypes.float32) # t642: "cuda:0 f32[1, 512, 4096]" - # t643 = prims.convert_element_type(t641, dtypes.float32) # t643: "cuda:0 f32[1, 512, 4096]" - # t644 = prims.mul(t642, t643) # t644: "cuda:0 f32[1, 512, 4096]" - # t645 = prims.convert_element_type(t644, dtypes.bfloat16) # t645: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t650 = ltorch.linear(t645, t_transformer_h_3_mlp_fc_1_weight, None) # t650: "cuda:0 bf16[1, 512, 11008]" - # t650 = prims.linear(t645, t_transformer_h_3_mlp_fc_1_weight, None) # t650: "cuda:0 bf16[1, 512, 11008]" - t654 = ltorch.linear(t645, t_transformer_h_3_mlp_fc_2_weight, None) # t654: "cuda:0 bf16[1, 512, 11008]" - # t654 = prims.linear(t645, t_transformer_h_3_mlp_fc_2_weight, None) # t654: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t664 = ltorch.silu(t650, False) # t664: "cuda:0 bf16[1, 512, 11008]" - # t655 = prims.convert_element_type(t650, dtypes.float32) # t655: "cuda:0 f32[1, 512, 11008]" - # t656 = prims.neg(t655) # t656: "cuda:0 f32[1, 512, 11008]" - # t657 = prims.exp(t656) # t657: "cuda:0 f32[1, 512, 11008]" - # t658 = prims.add(1.0, t657) # t658: "cuda:0 f32[1, 512, 11008]" - # t659 = prims.reciprocal(t658) # t659: "cuda:0 f32[1, 512, 11008]" - # t660 = prims.convert_element_type(t659, dtypes.bfloat16) # t660: "cuda:0 bf16[1, 512, 11008]" - # t661 = prims.convert_element_type(t650, dtypes.float32) # t661: "cuda:0 f32[1, 512, 11008]" - # t662 = prims.convert_element_type(t660, dtypes.float32) # t662: "cuda:0 f32[1, 512, 11008]" - # t663 = prims.mul(t661, t662) # t663: "cuda:0 f32[1, 512, 11008]" - # t664 = prims.convert_element_type(t663, dtypes.bfloat16) # t664: "cuda:0 bf16[1, 512, 11008]" - t668 = ltorch.mul(t664, t654) # t668: "cuda:0 bf16[1, 512, 11008]" - # t665 = prims.convert_element_type(t664, dtypes.float32) # t665: "cuda:0 f32[1, 512, 11008]" - # t666 = prims.convert_element_type(t654, dtypes.float32) # t666: "cuda:0 f32[1, 512, 11008]" - # t667 = prims.mul(t665, t666) # t667: "cuda:0 f32[1, 512, 11008]" - # t668 = prims.convert_element_type(t667, dtypes.bfloat16) # t668: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t672 = ltorch.linear(t668, t_transformer_h_3_mlp_proj_weight, None) # t672: "cuda:0 bf16[1, 512, 4096]" - # t672 = prims.linear(t668, t_transformer_h_3_mlp_proj_weight, None) # t672: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t676 = ltorch.add(t672, t623, alpha=None) # t676: "cuda:0 bf16[1, 512, 4096]" - # t673 = prims.convert_element_type(t672, dtypes.float32) # t673: "cuda:0 f32[1, 512, 4096]" - # t674 = prims.convert_element_type(t623, dtypes.float32) # t674: "cuda:0 f32[1, 512, 4096]" - # t675 = prims.add(t673, t674) # t675: "cuda:0 f32[1, 512, 4096]" - # t676 = prims.convert_element_type(t675, dtypes.bfloat16) # t676: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t678 = prims.convert_element_type(t676, dtypes.float32) # t678: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t679 = ltorch.mul(t678, t678) # t679: "cuda:0 f32[1, 512, 4096]" - # t679 = prims.mul(t678, t678) # t679: "cuda:0 f32[1, 512, 4096]" - t683 = ltorch.mean(t679, -1, True, dtype=None) # t683: "cuda:0 f32[1, 512, 1]" - # t681 = prims.sum(t679, (2,)) # t681: "cuda:0 f32[1, 512]" - # t682 = prims.broadcast_in_dim(t681, [1, 512, 1], [0, 1]) # t682: "cuda:0 f32[1, 512, 1]" - # t683 = ltorch.true_divide(t682, 4096) # t683: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t683 = prims.div(t682, 4096.0) # t683: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t685 = ltorch.add(t683, 1e-05, alpha=None) # t685: "cuda:0 f32[1, 512, 1]" - # t685 = prims.add(t683, 1e-05) # t685: "cuda:0 f32[1, 512, 1]" - t686 = ltorch.rsqrt(t685) # t686: "cuda:0 f32[1, 512, 1]" - # t686 = prims.rsqrt(t685) # t686: "cuda:0 f32[1, 512, 1]" - t688 = ltorch.mul(t678, t686) # t688: "cuda:0 f32[1, 512, 4096]" - # t687 = prims.broadcast_in_dim(t686, (1, 512, 4096), (0, 1, 2)) # t687: "cuda:0 f32[1, 512, 4096]" - # t688 = prims.mul(t678, t687) # t688: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t689 = ltorch.to(t688, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t689: "cuda:0 bf16[1, 512, 4096]" - # t689 = prims.convert_element_type(t688, dtypes.bfloat16) # t689: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t699 = ltorch.mul(t689, t_transformer_h_4_norm_1_weight) # t699: "cuda:0 bf16[1, 512, 4096]" - # t695 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_weight, (1, 512, 4096), (2,)) # t695: "cuda:0 bf16[1, 512, 4096]" - # t696 = prims.convert_element_type(t689, dtypes.float32) # t696: "cuda:0 f32[1, 512, 4096]" - # t697 = prims.convert_element_type(t695, dtypes.float32) # t697: "cuda:0 f32[1, 512, 4096]" - # t698 = prims.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 4096]" - # t699 = prims.convert_element_type(t698, dtypes.bfloat16) # t699: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t704 = ltorch.linear(t699, t_transformer_h_4_attn_attn_weight, None) # t704: "cuda:0 bf16[1, 512, 12288]" - # t704 = prims.linear(t699, t_transformer_h_4_attn_attn_weight, None) # t704: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t705 = ltorch.view(t704, 1, 512, 32, 3, 128) # t705: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t705 = ltorch.reshape(t704, (1, 512, 32, 3, 128)) # t705: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t705 = prims.reshape(t704, (1, 512, 32, 3, 128)) # t705: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t706 = ltorch.permute(t705, 0, 2, 3, 1, 4) # t706: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t706 = prims.transpose(t705, (0, 2, 3, 1, 4)) # t706: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t707, t708, t709) = ltorch.split(t706, (1, 1, 1), 2) - # t707 = prims.slice_prim(t706, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t707: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t708 = prims.slice_prim(t706, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t708: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t709 = prims.slice_prim(t706, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t709: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t710 = ltorch.reshape(t707, 1, -1, 512, 128) # t710: "cuda:0 bf16[1, 32, 512, 128]" - # t710 = prims.reshape(t707, (1, 32, 512, 128)) # t710: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t711 = ltorch.reshape(t708, 1, -1, 512, 128) # t711: "cuda:0 bf16[1, 32, 512, 128]" - # t711 = prims.reshape(t708, (1, 32, 512, 128)) # t711: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t712 = ltorch.reshape(t709, 1, -1, 512, 128) # t712: "cuda:0 bf16[1, 32, 512, 128]" - # t712 = prims.reshape(t709, (1, 32, 512, 128)) # t712: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t713 = ltorch.getitem(t710, (..., slice(None, 128, None))) # t713: "cuda:0 bf16[1, 32, 512, 128]" - # t713 = prims.slice_prim(t710, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t713: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t714 = ltorch.getitem(t713, (..., slice(None, 64, None))) # t714: "cuda:0 bf16[1, 32, 512, 64]" - # t714 = prims.slice_prim(t713, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t714: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t715 = ltorch.getitem(t713, (..., slice(64, None, None))) # t715: "cuda:0 bf16[1, 32, 512, 64]" - # t715 = prims.slice_prim(t713, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t715: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t718 = ltorch.neg(t715) # t718: "cuda:0 bf16[1, 32, 512, 64]" - # t716 = prims.convert_element_type(t715, dtypes.float32) # t716: "cuda:0 f32[1, 32, 512, 64]" - # t717 = prims.neg(t716) # t717: "cuda:0 f32[1, 32, 512, 64]" - # t718 = prims.convert_element_type(t717, dtypes.bfloat16) # t718: "cuda:0 bf16[1, 32, 512, 64]" - t719 = ltorch.cat((t718, t714), -1) # t719: "cuda:0 bf16[1, 32, 512, 128]" - # t719 = prims.cat((t718, t714), -1) # t719: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t722 = ltorch.mul(t713, cos) # t722: "cuda:0 f32[1, 32, 512, 128]" - # t720 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t720: "cuda:0 f32[1, 32, 512, 128]" - # t721 = prims.convert_element_type(t713, dtypes.float32) # t721: "cuda:0 f32[1, 32, 512, 128]" - # t722 = prims.mul(t721, t720) # t722: "cuda:0 f32[1, 32, 512, 128]" - t725 = ltorch.mul(t719, sin) # t725: "cuda:0 f32[1, 32, 512, 128]" - # t723 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t723: "cuda:0 f32[1, 32, 512, 128]" - # t724 = prims.convert_element_type(t719, dtypes.float32) # t724: "cuda:0 f32[1, 32, 512, 128]" - # t725 = prims.mul(t724, t723) # t725: "cuda:0 f32[1, 32, 512, 128]" - t726 = ltorch.add(t722, t725, alpha=None) # t726: "cuda:0 f32[1, 32, 512, 128]" - # t726 = prims.add(t722, t725) # t726: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t727 = ltorch.to(t726, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t727: "cuda:0 bf16[1, 32, 512, 128]" - # t727 = prims.convert_element_type(t726, dtypes.bfloat16) # t727: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t728 = ltorch.getitem(t711, (..., slice(None, 128, None))) # t728: "cuda:0 bf16[1, 32, 512, 128]" - # t728 = prims.slice_prim(t711, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t728: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t729 = ltorch.getitem(t728, (..., slice(None, 64, None))) # t729: "cuda:0 bf16[1, 32, 512, 64]" - # t729 = prims.slice_prim(t728, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t729: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t730 = ltorch.getitem(t728, (..., slice(64, None, None))) # t730: "cuda:0 bf16[1, 32, 512, 64]" - # t730 = prims.slice_prim(t728, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t730: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t733 = ltorch.neg(t730) # t733: "cuda:0 bf16[1, 32, 512, 64]" - # t731 = prims.convert_element_type(t730, dtypes.float32) # t731: "cuda:0 f32[1, 32, 512, 64]" - # t732 = prims.neg(t731) # t732: "cuda:0 f32[1, 32, 512, 64]" - # t733 = prims.convert_element_type(t732, dtypes.bfloat16) # t733: "cuda:0 bf16[1, 32, 512, 64]" - t734 = ltorch.cat((t733, t729), -1) # t734: "cuda:0 bf16[1, 32, 512, 128]" - # t734 = prims.cat((t733, t729), -1) # t734: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t737 = ltorch.mul(t728, cos) # t737: "cuda:0 f32[1, 32, 512, 128]" - # t735 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t735: "cuda:0 f32[1, 32, 512, 128]" - # t736 = prims.convert_element_type(t728, dtypes.float32) # t736: "cuda:0 f32[1, 32, 512, 128]" - # t737 = prims.mul(t736, t735) # t737: "cuda:0 f32[1, 32, 512, 128]" - t740 = ltorch.mul(t734, sin) # t740: "cuda:0 f32[1, 32, 512, 128]" - # t738 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t738: "cuda:0 f32[1, 32, 512, 128]" - # t739 = prims.convert_element_type(t734, dtypes.float32) # t739: "cuda:0 f32[1, 32, 512, 128]" - # t740 = prims.mul(t739, t738) # t740: "cuda:0 f32[1, 32, 512, 128]" - t741 = ltorch.add(t737, t740, alpha=None) # t741: "cuda:0 f32[1, 32, 512, 128]" - # t741 = prims.add(t737, t740) # t741: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t742 = ltorch.to(t741, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t742: "cuda:0 bf16[1, 32, 512, 128]" - # t742 = prims.convert_element_type(t741, dtypes.bfloat16) # t742: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t743 = ltorch.getitem(t710, (..., slice(128, None, None))) # t743: "cuda:0 bf16[1, 32, 512, 0]" - # t743 = prims.slice_prim(t710, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t743: "cuda:0 bf16[1, 32, 512, 0]" - t744 = ltorch.cat((t727, t743), -1) # t744: "cuda:0 bf16[1, 32, 512, 128]" - # t744 = prims.cat((t727, t743), -1) # t744: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t745 = ltorch.getitem(t711, (..., slice(128, None, None))) # t745: "cuda:0 bf16[1, 32, 512, 0]" - # t745 = prims.slice_prim(t711, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t745: "cuda:0 bf16[1, 32, 512, 0]" - t746 = ltorch.cat((t742, t745), -1) # t746: "cuda:0 bf16[1, 32, 512, 128]" - # t746 = prims.cat((t742, t745), -1) # t746: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t776 = ltorch.scaled_dot_product_attention(t744, t746, t712, None, 0.0, True, scale=0.08838834764831843) # t776: "cuda:0 bf16[1, 32, 512, 128]" - # t749 = ltorch.mul(t744, 0.29730177875068026) # t749: "cuda:0 bf16[1, 32, 512, 128]" - # t747 = prims.convert_element_type(t744, dtypes.float32) # t747: "cuda:0 f32[1, 32, 512, 128]" - # t748 = prims.mul(t747, 0.29730177875068026) # t748: "cuda:0 f32[1, 32, 512, 128]" - # t749 = prims.convert_element_type(t748, dtypes.bfloat16) # t749: "cuda:0 bf16[1, 32, 512, 128]" - # t750 = ltorch.transpose(t746, -2, -1) # t750: "cuda:0 bf16[1, 32, 128, 512]" - # t750 = prims.transpose(t746, (0, 1, 3, 2)) # t750: "cuda:0 bf16[1, 32, 128, 512]" - # t753 = ltorch.mul(t750, 0.29730177875068026) # t753: "cuda:0 bf16[1, 32, 128, 512]" - # t751 = prims.convert_element_type(t750, dtypes.float32) # t751: "cuda:0 f32[1, 32, 128, 512]" - # t752 = prims.mul(t751, 0.29730177875068026) # t752: "cuda:0 f32[1, 32, 128, 512]" - # t753 = prims.convert_element_type(t752, dtypes.bfloat16) # t753: "cuda:0 bf16[1, 32, 128, 512]" - # t754 = ltorch.matmul(t749, t753) # t754: "cuda:0 bf16[1, 32, 512, 512]" - # t754 = prims.matmul(t749, t753) # t754: "cuda:0 bf16[1, 32, 512, 512]" - # t764 = ltorch.tril(t754, 0, fill_value=-float('inf')) # t764: "cuda:0 bf16[1, 32, 512, 512]" - # t755 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t755: "cuda:0 i64[512]" - # t755 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t755: "cuda:0 i64[512]" - # t756 = ltorch.unsqueeze(t755, -1) # t756: "cuda:0 i64[512, 1]" - # t756 = prims.broadcast_in_dim(t755, [512, 1], [0]) # t756: "cuda:0 i64[512, 1]" - # t757 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t757: "cuda:0 i64[512]" - # t757 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t757: "cuda:0 i64[512]" - # t758 = ltorch.unsqueeze(t757, -2) # t758: "cuda:0 i64[1, 512]" - # t758 = prims.broadcast_in_dim(t757, [1, 512], [1]) # t758: "cuda:0 i64[1, 512]" - # t759 = ltorch.add(t756, 0, alpha=None) # t759: "cuda:0 i64[512, 1]" - # t759 = prims.add(t756, 0) # t759: "cuda:0 i64[512, 1]" - # t762 = ltorch.ge(t759, t758) # t762: "cuda:0 b8[512, 512]" - # t760 = prims.broadcast_in_dim(t759, (512, 512), (0, 1)) # t760: "cuda:0 i64[512, 512]" - # t761 = prims.broadcast_in_dim(t758, (512, 512), (0, 1)) # t761: "cuda:0 i64[512, 512]" - # t762 = prims.ge(t760, t761) # t762: "cuda:0 b8[512, 512]" - # t764 = ltorch.where(t762, t754, -float('inf')) # t764: "cuda:0 bf16[1, 32, 512, 512]" - # t763 = prims.broadcast_in_dim(t762, (1, 32, 512, 512), (2, 3)) # t763: "cuda:0 b8[1, 32, 512, 512]" - # t764 = prims.where(t763, t754, -float('inf')) # t764: "cuda:0 bf16[1, 32, 512, 512]" - # t775 = ltorch._softmax(t764, -1, dtype=None) # t775: "cuda:0 bf16[1, 32, 512, 512]" - # t765 = ltorch.to(t764, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t765: "cuda:0 f32[1, 32, 512, 512]" - # t765 = prims.convert_element_type(t764, dtypes.float32) # t765: "cuda:0 f32[1, 32, 512, 512]" - # t767 = ltorch.amax(t765, -1, True) # t767: "cuda:0 f32[1, 32, 512, 1]" - # t766 = prims.amax(t765, (3,)) # t766: "cuda:0 f32[1, 32, 512]" - # t767 = prims.broadcast_in_dim(t766, [1, 32, 512, 1], [0, 1, 2]) # t767: "cuda:0 f32[1, 32, 512, 1]" - # t769 = ltorch.sub(t765, t767, alpha=None) # t769: "cuda:0 f32[1, 32, 512, 512]" - # t768 = prims.broadcast_in_dim(t767, (1, 32, 512, 512), (0, 1, 2, 3)) # t768: "cuda:0 f32[1, 32, 512, 512]" - # t769 = prims.sub(t765, t768) # t769: "cuda:0 f32[1, 32, 512, 512]" - # t770 = ltorch.exp(t769) # t770: "cuda:0 f32[1, 32, 512, 512]" - # t770 = prims.exp(t769) # t770: "cuda:0 f32[1, 32, 512, 512]" - # t772 = ltorch.sum(t770, -1, True, dtype=None) # t772: "cuda:0 f32[1, 32, 512, 1]" - # t771 = prims.sum(t770, (3,)) # t771: "cuda:0 f32[1, 32, 512]" - # t772 = prims.broadcast_in_dim(t771, [1, 32, 512, 1], [0, 1, 2]) # t772: "cuda:0 f32[1, 32, 512, 1]" - # t774 = ltorch.true_divide(t770, t772) # t774: "cuda:0 f32[1, 32, 512, 512]" - # t773 = prims.broadcast_in_dim(t772, (1, 32, 512, 512), (0, 1, 2, 3)) # t773: "cuda:0 f32[1, 32, 512, 512]" - # t774 = prims.div(t770, t773) # t774: "cuda:0 f32[1, 32, 512, 512]" - # t775 = ltorch.to(t774, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t775: "cuda:0 bf16[1, 32, 512, 512]" - # t775 = prims.convert_element_type(t774, dtypes.bfloat16) # t775: "cuda:0 bf16[1, 32, 512, 512]" - # t776 = ltorch.matmul(t775, t712) # t776: "cuda:0 bf16[1, 32, 512, 128]" - # t776 = prims.matmul(t775, t712) # t776: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t777 = ltorch.transpose(t776, 1, 2) # t777: "cuda:0 bf16[1, 512, 32, 128]" - # t777 = prims.transpose(t776, (0, 2, 1, 3)) # t777: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t778 = ltorch.reshape(t777, 1, 512, 4096) # t778: "cuda:0 bf16[1, 512, 4096]" - # t778 = prims.reshape(t777, (1, 512, 4096)) # t778: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t782 = ltorch.linear(t778, t_transformer_h_4_attn_proj_weight, None) # t782: "cuda:0 bf16[1, 512, 4096]" - # t782 = prims.linear(t778, t_transformer_h_4_attn_proj_weight, None) # t782: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t786 = ltorch.add(t782, t676, alpha=None) # t786: "cuda:0 bf16[1, 512, 4096]" - # t783 = prims.convert_element_type(t782, dtypes.float32) # t783: "cuda:0 f32[1, 512, 4096]" - # t784 = prims.convert_element_type(t676, dtypes.float32) # t784: "cuda:0 f32[1, 512, 4096]" - # t785 = prims.add(t783, t784) # t785: "cuda:0 f32[1, 512, 4096]" - # t786 = prims.convert_element_type(t785, dtypes.bfloat16) # t786: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t787 = prims.convert_element_type(t786, dtypes.float32) # t787: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t788 = ltorch.mul(t787, t787) # t788: "cuda:0 f32[1, 512, 4096]" - # t788 = prims.mul(t787, t787) # t788: "cuda:0 f32[1, 512, 4096]" - t792 = ltorch.mean(t788, -1, True, dtype=None) # t792: "cuda:0 f32[1, 512, 1]" - # t790 = prims.sum(t788, (2,)) # t790: "cuda:0 f32[1, 512]" - # t791 = prims.broadcast_in_dim(t790, [1, 512, 1], [0, 1]) # t791: "cuda:0 f32[1, 512, 1]" - # t792 = ltorch.true_divide(t791, 4096) # t792: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t792 = prims.div(t791, 4096.0) # t792: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t794 = ltorch.add(t792, 1e-05, alpha=None) # t794: "cuda:0 f32[1, 512, 1]" - # t794 = prims.add(t792, 1e-05) # t794: "cuda:0 f32[1, 512, 1]" - t795 = ltorch.rsqrt(t794) # t795: "cuda:0 f32[1, 512, 1]" - # t795 = prims.rsqrt(t794) # t795: "cuda:0 f32[1, 512, 1]" - t797 = ltorch.mul(t787, t795) # t797: "cuda:0 f32[1, 512, 4096]" - # t796 = prims.broadcast_in_dim(t795, (1, 512, 4096), (0, 1, 2)) # t796: "cuda:0 f32[1, 512, 4096]" - # t797 = prims.mul(t787, t796) # t797: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t798 = ltorch.to(t797, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t798: "cuda:0 bf16[1, 512, 4096]" - # t798 = prims.convert_element_type(t797, dtypes.bfloat16) # t798: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t808 = ltorch.mul(t798, t_transformer_h_4_norm_2_weight) # t808: "cuda:0 bf16[1, 512, 4096]" - # t804 = prims.broadcast_in_dim(t_transformer_h_4_norm_2_weight, (1, 512, 4096), (2,)) # t804: "cuda:0 bf16[1, 512, 4096]" - # t805 = prims.convert_element_type(t798, dtypes.float32) # t805: "cuda:0 f32[1, 512, 4096]" - # t806 = prims.convert_element_type(t804, dtypes.float32) # t806: "cuda:0 f32[1, 512, 4096]" - # t807 = prims.mul(t805, t806) # t807: "cuda:0 f32[1, 512, 4096]" - # t808 = prims.convert_element_type(t807, dtypes.bfloat16) # t808: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t813 = ltorch.linear(t808, t_transformer_h_4_mlp_fc_1_weight, None) # t813: "cuda:0 bf16[1, 512, 11008]" - # t813 = prims.linear(t808, t_transformer_h_4_mlp_fc_1_weight, None) # t813: "cuda:0 bf16[1, 512, 11008]" - t817 = ltorch.linear(t808, t_transformer_h_4_mlp_fc_2_weight, None) # t817: "cuda:0 bf16[1, 512, 11008]" - # t817 = prims.linear(t808, t_transformer_h_4_mlp_fc_2_weight, None) # t817: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t827 = ltorch.silu(t813, False) # t827: "cuda:0 bf16[1, 512, 11008]" - # t818 = prims.convert_element_type(t813, dtypes.float32) # t818: "cuda:0 f32[1, 512, 11008]" - # t819 = prims.neg(t818) # t819: "cuda:0 f32[1, 512, 11008]" - # t820 = prims.exp(t819) # t820: "cuda:0 f32[1, 512, 11008]" - # t821 = prims.add(1.0, t820) # t821: "cuda:0 f32[1, 512, 11008]" - # t822 = prims.reciprocal(t821) # t822: "cuda:0 f32[1, 512, 11008]" - # t823 = prims.convert_element_type(t822, dtypes.bfloat16) # t823: "cuda:0 bf16[1, 512, 11008]" - # t824 = prims.convert_element_type(t813, dtypes.float32) # t824: "cuda:0 f32[1, 512, 11008]" - # t825 = prims.convert_element_type(t823, dtypes.float32) # t825: "cuda:0 f32[1, 512, 11008]" - # t826 = prims.mul(t824, t825) # t826: "cuda:0 f32[1, 512, 11008]" - # t827 = prims.convert_element_type(t826, dtypes.bfloat16) # t827: "cuda:0 bf16[1, 512, 11008]" - t831 = ltorch.mul(t827, t817) # t831: "cuda:0 bf16[1, 512, 11008]" - # t828 = prims.convert_element_type(t827, dtypes.float32) # t828: "cuda:0 f32[1, 512, 11008]" - # t829 = prims.convert_element_type(t817, dtypes.float32) # t829: "cuda:0 f32[1, 512, 11008]" - # t830 = prims.mul(t828, t829) # t830: "cuda:0 f32[1, 512, 11008]" - # t831 = prims.convert_element_type(t830, dtypes.bfloat16) # t831: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t835 = ltorch.linear(t831, t_transformer_h_4_mlp_proj_weight, None) # t835: "cuda:0 bf16[1, 512, 4096]" - # t835 = prims.linear(t831, t_transformer_h_4_mlp_proj_weight, None) # t835: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t839 = ltorch.add(t835, t786, alpha=None) # t839: "cuda:0 bf16[1, 512, 4096]" - # t836 = prims.convert_element_type(t835, dtypes.float32) # t836: "cuda:0 f32[1, 512, 4096]" - # t837 = prims.convert_element_type(t786, dtypes.float32) # t837: "cuda:0 f32[1, 512, 4096]" - # t838 = prims.add(t836, t837) # t838: "cuda:0 f32[1, 512, 4096]" - # t839 = prims.convert_element_type(t838, dtypes.bfloat16) # t839: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t841 = prims.convert_element_type(t839, dtypes.float32) # t841: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t842 = ltorch.mul(t841, t841) # t842: "cuda:0 f32[1, 512, 4096]" - # t842 = prims.mul(t841, t841) # t842: "cuda:0 f32[1, 512, 4096]" - t846 = ltorch.mean(t842, -1, True, dtype=None) # t846: "cuda:0 f32[1, 512, 1]" - # t844 = prims.sum(t842, (2,)) # t844: "cuda:0 f32[1, 512]" - # t845 = prims.broadcast_in_dim(t844, [1, 512, 1], [0, 1]) # t845: "cuda:0 f32[1, 512, 1]" - # t846 = ltorch.true_divide(t845, 4096) # t846: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t846 = prims.div(t845, 4096.0) # t846: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t848 = ltorch.add(t846, 1e-05, alpha=None) # t848: "cuda:0 f32[1, 512, 1]" - # t848 = prims.add(t846, 1e-05) # t848: "cuda:0 f32[1, 512, 1]" - t849 = ltorch.rsqrt(t848) # t849: "cuda:0 f32[1, 512, 1]" - # t849 = prims.rsqrt(t848) # t849: "cuda:0 f32[1, 512, 1]" - t851 = ltorch.mul(t841, t849) # t851: "cuda:0 f32[1, 512, 4096]" - # t850 = prims.broadcast_in_dim(t849, (1, 512, 4096), (0, 1, 2)) # t850: "cuda:0 f32[1, 512, 4096]" - # t851 = prims.mul(t841, t850) # t851: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t852 = ltorch.to(t851, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t852: "cuda:0 bf16[1, 512, 4096]" - # t852 = prims.convert_element_type(t851, dtypes.bfloat16) # t852: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t862 = ltorch.mul(t852, t_transformer_h_5_norm_1_weight) # t862: "cuda:0 bf16[1, 512, 4096]" - # t858 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_weight, (1, 512, 4096), (2,)) # t858: "cuda:0 bf16[1, 512, 4096]" - # t859 = prims.convert_element_type(t852, dtypes.float32) # t859: "cuda:0 f32[1, 512, 4096]" - # t860 = prims.convert_element_type(t858, dtypes.float32) # t860: "cuda:0 f32[1, 512, 4096]" - # t861 = prims.mul(t859, t860) # t861: "cuda:0 f32[1, 512, 4096]" - # t862 = prims.convert_element_type(t861, dtypes.bfloat16) # t862: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t867 = ltorch.linear(t862, t_transformer_h_5_attn_attn_weight, None) # t867: "cuda:0 bf16[1, 512, 12288]" - # t867 = prims.linear(t862, t_transformer_h_5_attn_attn_weight, None) # t867: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t868 = ltorch.view(t867, 1, 512, 32, 3, 128) # t868: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t868 = ltorch.reshape(t867, (1, 512, 32, 3, 128)) # t868: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t868 = prims.reshape(t867, (1, 512, 32, 3, 128)) # t868: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t869 = ltorch.permute(t868, 0, 2, 3, 1, 4) # t869: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t869 = prims.transpose(t868, (0, 2, 3, 1, 4)) # t869: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t870, t871, t872) = ltorch.split(t869, (1, 1, 1), 2) - # t870 = prims.slice_prim(t869, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t870: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t871 = prims.slice_prim(t869, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t871: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t872 = prims.slice_prim(t869, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t872: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t873 = ltorch.reshape(t870, 1, -1, 512, 128) # t873: "cuda:0 bf16[1, 32, 512, 128]" - # t873 = prims.reshape(t870, (1, 32, 512, 128)) # t873: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t874 = ltorch.reshape(t871, 1, -1, 512, 128) # t874: "cuda:0 bf16[1, 32, 512, 128]" - # t874 = prims.reshape(t871, (1, 32, 512, 128)) # t874: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t875 = ltorch.reshape(t872, 1, -1, 512, 128) # t875: "cuda:0 bf16[1, 32, 512, 128]" - # t875 = prims.reshape(t872, (1, 32, 512, 128)) # t875: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t876 = ltorch.getitem(t873, (..., slice(None, 128, None))) # t876: "cuda:0 bf16[1, 32, 512, 128]" - # t876 = prims.slice_prim(t873, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t876: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t877 = ltorch.getitem(t876, (..., slice(None, 64, None))) # t877: "cuda:0 bf16[1, 32, 512, 64]" - # t877 = prims.slice_prim(t876, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t877: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t878 = ltorch.getitem(t876, (..., slice(64, None, None))) # t878: "cuda:0 bf16[1, 32, 512, 64]" - # t878 = prims.slice_prim(t876, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t878: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t881 = ltorch.neg(t878) # t881: "cuda:0 bf16[1, 32, 512, 64]" - # t879 = prims.convert_element_type(t878, dtypes.float32) # t879: "cuda:0 f32[1, 32, 512, 64]" - # t880 = prims.neg(t879) # t880: "cuda:0 f32[1, 32, 512, 64]" - # t881 = prims.convert_element_type(t880, dtypes.bfloat16) # t881: "cuda:0 bf16[1, 32, 512, 64]" - t882 = ltorch.cat((t881, t877), -1) # t882: "cuda:0 bf16[1, 32, 512, 128]" - # t882 = prims.cat((t881, t877), -1) # t882: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t885 = ltorch.mul(t876, cos) # t885: "cuda:0 f32[1, 32, 512, 128]" - # t883 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t883: "cuda:0 f32[1, 32, 512, 128]" - # t884 = prims.convert_element_type(t876, dtypes.float32) # t884: "cuda:0 f32[1, 32, 512, 128]" - # t885 = prims.mul(t884, t883) # t885: "cuda:0 f32[1, 32, 512, 128]" - t888 = ltorch.mul(t882, sin) # t888: "cuda:0 f32[1, 32, 512, 128]" - # t886 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t886: "cuda:0 f32[1, 32, 512, 128]" - # t887 = prims.convert_element_type(t882, dtypes.float32) # t887: "cuda:0 f32[1, 32, 512, 128]" - # t888 = prims.mul(t887, t886) # t888: "cuda:0 f32[1, 32, 512, 128]" - t889 = ltorch.add(t885, t888, alpha=None) # t889: "cuda:0 f32[1, 32, 512, 128]" - # t889 = prims.add(t885, t888) # t889: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t890 = ltorch.to(t889, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t890: "cuda:0 bf16[1, 32, 512, 128]" - # t890 = prims.convert_element_type(t889, dtypes.bfloat16) # t890: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t891 = ltorch.getitem(t874, (..., slice(None, 128, None))) # t891: "cuda:0 bf16[1, 32, 512, 128]" - # t891 = prims.slice_prim(t874, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t891: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t892 = ltorch.getitem(t891, (..., slice(None, 64, None))) # t892: "cuda:0 bf16[1, 32, 512, 64]" - # t892 = prims.slice_prim(t891, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t892: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t893 = ltorch.getitem(t891, (..., slice(64, None, None))) # t893: "cuda:0 bf16[1, 32, 512, 64]" - # t893 = prims.slice_prim(t891, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t893: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t896 = ltorch.neg(t893) # t896: "cuda:0 bf16[1, 32, 512, 64]" - # t894 = prims.convert_element_type(t893, dtypes.float32) # t894: "cuda:0 f32[1, 32, 512, 64]" - # t895 = prims.neg(t894) # t895: "cuda:0 f32[1, 32, 512, 64]" - # t896 = prims.convert_element_type(t895, dtypes.bfloat16) # t896: "cuda:0 bf16[1, 32, 512, 64]" - t897 = ltorch.cat((t896, t892), -1) # t897: "cuda:0 bf16[1, 32, 512, 128]" - # t897 = prims.cat((t896, t892), -1) # t897: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t900 = ltorch.mul(t891, cos) # t900: "cuda:0 f32[1, 32, 512, 128]" - # t898 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t898: "cuda:0 f32[1, 32, 512, 128]" - # t899 = prims.convert_element_type(t891, dtypes.float32) # t899: "cuda:0 f32[1, 32, 512, 128]" - # t900 = prims.mul(t899, t898) # t900: "cuda:0 f32[1, 32, 512, 128]" - t903 = ltorch.mul(t897, sin) # t903: "cuda:0 f32[1, 32, 512, 128]" - # t901 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t901: "cuda:0 f32[1, 32, 512, 128]" - # t902 = prims.convert_element_type(t897, dtypes.float32) # t902: "cuda:0 f32[1, 32, 512, 128]" - # t903 = prims.mul(t902, t901) # t903: "cuda:0 f32[1, 32, 512, 128]" - t904 = ltorch.add(t900, t903, alpha=None) # t904: "cuda:0 f32[1, 32, 512, 128]" - # t904 = prims.add(t900, t903) # t904: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t905 = ltorch.to(t904, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t905: "cuda:0 bf16[1, 32, 512, 128]" - # t905 = prims.convert_element_type(t904, dtypes.bfloat16) # t905: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t906 = ltorch.getitem(t873, (..., slice(128, None, None))) # t906: "cuda:0 bf16[1, 32, 512, 0]" - # t906 = prims.slice_prim(t873, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t906: "cuda:0 bf16[1, 32, 512, 0]" - t907 = ltorch.cat((t890, t906), -1) # t907: "cuda:0 bf16[1, 32, 512, 128]" - # t907 = prims.cat((t890, t906), -1) # t907: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t908 = ltorch.getitem(t874, (..., slice(128, None, None))) # t908: "cuda:0 bf16[1, 32, 512, 0]" - # t908 = prims.slice_prim(t874, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t908: "cuda:0 bf16[1, 32, 512, 0]" - t909 = ltorch.cat((t905, t908), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - # t909 = prims.cat((t905, t908), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t939 = ltorch.scaled_dot_product_attention(t907, t909, t875, None, 0.0, True, scale=0.08838834764831843) # t939: "cuda:0 bf16[1, 32, 512, 128]" - # t912 = ltorch.mul(t907, 0.29730177875068026) # t912: "cuda:0 bf16[1, 32, 512, 128]" - # t910 = prims.convert_element_type(t907, dtypes.float32) # t910: "cuda:0 f32[1, 32, 512, 128]" - # t911 = prims.mul(t910, 0.29730177875068026) # t911: "cuda:0 f32[1, 32, 512, 128]" - # t912 = prims.convert_element_type(t911, dtypes.bfloat16) # t912: "cuda:0 bf16[1, 32, 512, 128]" - # t913 = ltorch.transpose(t909, -2, -1) # t913: "cuda:0 bf16[1, 32, 128, 512]" - # t913 = prims.transpose(t909, (0, 1, 3, 2)) # t913: "cuda:0 bf16[1, 32, 128, 512]" - # t916 = ltorch.mul(t913, 0.29730177875068026) # t916: "cuda:0 bf16[1, 32, 128, 512]" - # t914 = prims.convert_element_type(t913, dtypes.float32) # t914: "cuda:0 f32[1, 32, 128, 512]" - # t915 = prims.mul(t914, 0.29730177875068026) # t915: "cuda:0 f32[1, 32, 128, 512]" - # t916 = prims.convert_element_type(t915, dtypes.bfloat16) # t916: "cuda:0 bf16[1, 32, 128, 512]" - # t917 = ltorch.matmul(t912, t916) # t917: "cuda:0 bf16[1, 32, 512, 512]" - # t917 = prims.matmul(t912, t916) # t917: "cuda:0 bf16[1, 32, 512, 512]" - # t927 = ltorch.tril(t917, 0, fill_value=-float('inf')) # t927: "cuda:0 bf16[1, 32, 512, 512]" - # t918 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t918: "cuda:0 i64[512]" - # t918 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t918: "cuda:0 i64[512]" - # t919 = ltorch.unsqueeze(t918, -1) # t919: "cuda:0 i64[512, 1]" - # t919 = prims.broadcast_in_dim(t918, [512, 1], [0]) # t919: "cuda:0 i64[512, 1]" - # t920 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t920: "cuda:0 i64[512]" - # t920 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t920: "cuda:0 i64[512]" - # t921 = ltorch.unsqueeze(t920, -2) # t921: "cuda:0 i64[1, 512]" - # t921 = prims.broadcast_in_dim(t920, [1, 512], [1]) # t921: "cuda:0 i64[1, 512]" - # t922 = ltorch.add(t919, 0, alpha=None) # t922: "cuda:0 i64[512, 1]" - # t922 = prims.add(t919, 0) # t922: "cuda:0 i64[512, 1]" - # t925 = ltorch.ge(t922, t921) # t925: "cuda:0 b8[512, 512]" - # t923 = prims.broadcast_in_dim(t922, (512, 512), (0, 1)) # t923: "cuda:0 i64[512, 512]" - # t924 = prims.broadcast_in_dim(t921, (512, 512), (0, 1)) # t924: "cuda:0 i64[512, 512]" - # t925 = prims.ge(t923, t924) # t925: "cuda:0 b8[512, 512]" - # t927 = ltorch.where(t925, t917, -float('inf')) # t927: "cuda:0 bf16[1, 32, 512, 512]" - # t926 = prims.broadcast_in_dim(t925, (1, 32, 512, 512), (2, 3)) # t926: "cuda:0 b8[1, 32, 512, 512]" - # t927 = prims.where(t926, t917, -float('inf')) # t927: "cuda:0 bf16[1, 32, 512, 512]" - # t938 = ltorch._softmax(t927, -1, dtype=None) # t938: "cuda:0 bf16[1, 32, 512, 512]" - # t928 = ltorch.to(t927, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t928: "cuda:0 f32[1, 32, 512, 512]" - # t928 = prims.convert_element_type(t927, dtypes.float32) # t928: "cuda:0 f32[1, 32, 512, 512]" - # t930 = ltorch.amax(t928, -1, True) # t930: "cuda:0 f32[1, 32, 512, 1]" - # t929 = prims.amax(t928, (3,)) # t929: "cuda:0 f32[1, 32, 512]" - # t930 = prims.broadcast_in_dim(t929, [1, 32, 512, 1], [0, 1, 2]) # t930: "cuda:0 f32[1, 32, 512, 1]" - # t932 = ltorch.sub(t928, t930, alpha=None) # t932: "cuda:0 f32[1, 32, 512, 512]" - # t931 = prims.broadcast_in_dim(t930, (1, 32, 512, 512), (0, 1, 2, 3)) # t931: "cuda:0 f32[1, 32, 512, 512]" - # t932 = prims.sub(t928, t931) # t932: "cuda:0 f32[1, 32, 512, 512]" - # t933 = ltorch.exp(t932) # t933: "cuda:0 f32[1, 32, 512, 512]" - # t933 = prims.exp(t932) # t933: "cuda:0 f32[1, 32, 512, 512]" - # t935 = ltorch.sum(t933, -1, True, dtype=None) # t935: "cuda:0 f32[1, 32, 512, 1]" - # t934 = prims.sum(t933, (3,)) # t934: "cuda:0 f32[1, 32, 512]" - # t935 = prims.broadcast_in_dim(t934, [1, 32, 512, 1], [0, 1, 2]) # t935: "cuda:0 f32[1, 32, 512, 1]" - # t937 = ltorch.true_divide(t933, t935) # t937: "cuda:0 f32[1, 32, 512, 512]" - # t936 = prims.broadcast_in_dim(t935, (1, 32, 512, 512), (0, 1, 2, 3)) # t936: "cuda:0 f32[1, 32, 512, 512]" - # t937 = prims.div(t933, t936) # t937: "cuda:0 f32[1, 32, 512, 512]" - # t938 = ltorch.to(t937, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t938: "cuda:0 bf16[1, 32, 512, 512]" - # t938 = prims.convert_element_type(t937, dtypes.bfloat16) # t938: "cuda:0 bf16[1, 32, 512, 512]" - # t939 = ltorch.matmul(t938, t875) # t939: "cuda:0 bf16[1, 32, 512, 128]" - # t939 = prims.matmul(t938, t875) # t939: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t940 = ltorch.transpose(t939, 1, 2) # t940: "cuda:0 bf16[1, 512, 32, 128]" - # t940 = prims.transpose(t939, (0, 2, 1, 3)) # t940: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t941 = ltorch.reshape(t940, 1, 512, 4096) # t941: "cuda:0 bf16[1, 512, 4096]" - # t941 = prims.reshape(t940, (1, 512, 4096)) # t941: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t945 = ltorch.linear(t941, t_transformer_h_5_attn_proj_weight, None) # t945: "cuda:0 bf16[1, 512, 4096]" - # t945 = prims.linear(t941, t_transformer_h_5_attn_proj_weight, None) # t945: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t949 = ltorch.add(t945, t839, alpha=None) # t949: "cuda:0 bf16[1, 512, 4096]" - # t946 = prims.convert_element_type(t945, dtypes.float32) # t946: "cuda:0 f32[1, 512, 4096]" - # t947 = prims.convert_element_type(t839, dtypes.float32) # t947: "cuda:0 f32[1, 512, 4096]" - # t948 = prims.add(t946, t947) # t948: "cuda:0 f32[1, 512, 4096]" - # t949 = prims.convert_element_type(t948, dtypes.bfloat16) # t949: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t950 = prims.convert_element_type(t949, dtypes.float32) # t950: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t951 = ltorch.mul(t950, t950) # t951: "cuda:0 f32[1, 512, 4096]" - # t951 = prims.mul(t950, t950) # t951: "cuda:0 f32[1, 512, 4096]" - t955 = ltorch.mean(t951, -1, True, dtype=None) # t955: "cuda:0 f32[1, 512, 1]" - # t953 = prims.sum(t951, (2,)) # t953: "cuda:0 f32[1, 512]" - # t954 = prims.broadcast_in_dim(t953, [1, 512, 1], [0, 1]) # t954: "cuda:0 f32[1, 512, 1]" - # t955 = ltorch.true_divide(t954, 4096) # t955: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t955 = prims.div(t954, 4096.0) # t955: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t957 = ltorch.add(t955, 1e-05, alpha=None) # t957: "cuda:0 f32[1, 512, 1]" - # t957 = prims.add(t955, 1e-05) # t957: "cuda:0 f32[1, 512, 1]" - t958 = ltorch.rsqrt(t957) # t958: "cuda:0 f32[1, 512, 1]" - # t958 = prims.rsqrt(t957) # t958: "cuda:0 f32[1, 512, 1]" - t960 = ltorch.mul(t950, t958) # t960: "cuda:0 f32[1, 512, 4096]" - # t959 = prims.broadcast_in_dim(t958, (1, 512, 4096), (0, 1, 2)) # t959: "cuda:0 f32[1, 512, 4096]" - # t960 = prims.mul(t950, t959) # t960: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t961 = ltorch.to(t960, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t961: "cuda:0 bf16[1, 512, 4096]" - # t961 = prims.convert_element_type(t960, dtypes.bfloat16) # t961: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t971 = ltorch.mul(t961, t_transformer_h_5_norm_2_weight) # t971: "cuda:0 bf16[1, 512, 4096]" - # t967 = prims.broadcast_in_dim(t_transformer_h_5_norm_2_weight, (1, 512, 4096), (2,)) # t967: "cuda:0 bf16[1, 512, 4096]" - # t968 = prims.convert_element_type(t961, dtypes.float32) # t968: "cuda:0 f32[1, 512, 4096]" - # t969 = prims.convert_element_type(t967, dtypes.float32) # t969: "cuda:0 f32[1, 512, 4096]" - # t970 = prims.mul(t968, t969) # t970: "cuda:0 f32[1, 512, 4096]" - # t971 = prims.convert_element_type(t970, dtypes.bfloat16) # t971: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t976 = ltorch.linear(t971, t_transformer_h_5_mlp_fc_1_weight, None) # t976: "cuda:0 bf16[1, 512, 11008]" - # t976 = prims.linear(t971, t_transformer_h_5_mlp_fc_1_weight, None) # t976: "cuda:0 bf16[1, 512, 11008]" - t980 = ltorch.linear(t971, t_transformer_h_5_mlp_fc_2_weight, None) # t980: "cuda:0 bf16[1, 512, 11008]" - # t980 = prims.linear(t971, t_transformer_h_5_mlp_fc_2_weight, None) # t980: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t990 = ltorch.silu(t976, False) # t990: "cuda:0 bf16[1, 512, 11008]" - # t981 = prims.convert_element_type(t976, dtypes.float32) # t981: "cuda:0 f32[1, 512, 11008]" - # t982 = prims.neg(t981) # t982: "cuda:0 f32[1, 512, 11008]" - # t983 = prims.exp(t982) # t983: "cuda:0 f32[1, 512, 11008]" - # t984 = prims.add(1.0, t983) # t984: "cuda:0 f32[1, 512, 11008]" - # t985 = prims.reciprocal(t984) # t985: "cuda:0 f32[1, 512, 11008]" - # t986 = prims.convert_element_type(t985, dtypes.bfloat16) # t986: "cuda:0 bf16[1, 512, 11008]" - # t987 = prims.convert_element_type(t976, dtypes.float32) # t987: "cuda:0 f32[1, 512, 11008]" - # t988 = prims.convert_element_type(t986, dtypes.float32) # t988: "cuda:0 f32[1, 512, 11008]" - # t989 = prims.mul(t987, t988) # t989: "cuda:0 f32[1, 512, 11008]" - # t990 = prims.convert_element_type(t989, dtypes.bfloat16) # t990: "cuda:0 bf16[1, 512, 11008]" - t994 = ltorch.mul(t990, t980) # t994: "cuda:0 bf16[1, 512, 11008]" - # t991 = prims.convert_element_type(t990, dtypes.float32) # t991: "cuda:0 f32[1, 512, 11008]" - # t992 = prims.convert_element_type(t980, dtypes.float32) # t992: "cuda:0 f32[1, 512, 11008]" - # t993 = prims.mul(t991, t992) # t993: "cuda:0 f32[1, 512, 11008]" - # t994 = prims.convert_element_type(t993, dtypes.bfloat16) # t994: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t998 = ltorch.linear(t994, t_transformer_h_5_mlp_proj_weight, None) # t998: "cuda:0 bf16[1, 512, 4096]" - # t998 = prims.linear(t994, t_transformer_h_5_mlp_proj_weight, None) # t998: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1002 = ltorch.add(t998, t949, alpha=None) # t1002: "cuda:0 bf16[1, 512, 4096]" - # t999 = prims.convert_element_type(t998, dtypes.float32) # t999: "cuda:0 f32[1, 512, 4096]" - # t1000 = prims.convert_element_type(t949, dtypes.float32) # t1000: "cuda:0 f32[1, 512, 4096]" - # t1001 = prims.add(t999, t1000) # t1001: "cuda:0 f32[1, 512, 4096]" - # t1002 = prims.convert_element_type(t1001, dtypes.bfloat16) # t1002: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1004 = prims.convert_element_type(t1002, dtypes.float32) # t1004: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1005 = ltorch.mul(t1004, t1004) # t1005: "cuda:0 f32[1, 512, 4096]" - # t1005 = prims.mul(t1004, t1004) # t1005: "cuda:0 f32[1, 512, 4096]" - t1009 = ltorch.mean(t1005, -1, True, dtype=None) # t1009: "cuda:0 f32[1, 512, 1]" - # t1007 = prims.sum(t1005, (2,)) # t1007: "cuda:0 f32[1, 512]" - # t1008 = prims.broadcast_in_dim(t1007, [1, 512, 1], [0, 1]) # t1008: "cuda:0 f32[1, 512, 1]" - # t1009 = ltorch.true_divide(t1008, 4096) # t1009: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1009 = prims.div(t1008, 4096.0) # t1009: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1011 = ltorch.add(t1009, 1e-05, alpha=None) # t1011: "cuda:0 f32[1, 512, 1]" - # t1011 = prims.add(t1009, 1e-05) # t1011: "cuda:0 f32[1, 512, 1]" - t1012 = ltorch.rsqrt(t1011) # t1012: "cuda:0 f32[1, 512, 1]" - # t1012 = prims.rsqrt(t1011) # t1012: "cuda:0 f32[1, 512, 1]" - t1014 = ltorch.mul(t1004, t1012) # t1014: "cuda:0 f32[1, 512, 4096]" - # t1013 = prims.broadcast_in_dim(t1012, (1, 512, 4096), (0, 1, 2)) # t1013: "cuda:0 f32[1, 512, 4096]" - # t1014 = prims.mul(t1004, t1013) # t1014: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1015 = ltorch.to(t1014, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1015: "cuda:0 bf16[1, 512, 4096]" - # t1015 = prims.convert_element_type(t1014, dtypes.bfloat16) # t1015: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1025 = ltorch.mul(t1015, t_transformer_h_6_norm_1_weight) # t1025: "cuda:0 bf16[1, 512, 4096]" - # t1021 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_weight, (1, 512, 4096), (2,)) # t1021: "cuda:0 bf16[1, 512, 4096]" - # t1022 = prims.convert_element_type(t1015, dtypes.float32) # t1022: "cuda:0 f32[1, 512, 4096]" - # t1023 = prims.convert_element_type(t1021, dtypes.float32) # t1023: "cuda:0 f32[1, 512, 4096]" - # t1024 = prims.mul(t1022, t1023) # t1024: "cuda:0 f32[1, 512, 4096]" - # t1025 = prims.convert_element_type(t1024, dtypes.bfloat16) # t1025: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1030 = ltorch.linear(t1025, t_transformer_h_6_attn_attn_weight, None) # t1030: "cuda:0 bf16[1, 512, 12288]" - # t1030 = prims.linear(t1025, t_transformer_h_6_attn_attn_weight, None) # t1030: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1031 = ltorch.view(t1030, 1, 512, 32, 3, 128) # t1031: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1031 = ltorch.reshape(t1030, (1, 512, 32, 3, 128)) # t1031: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1031 = prims.reshape(t1030, (1, 512, 32, 3, 128)) # t1031: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1032 = ltorch.permute(t1031, 0, 2, 3, 1, 4) # t1032: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1032 = prims.transpose(t1031, (0, 2, 3, 1, 4)) # t1032: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1033, t1034, t1035) = ltorch.split(t1032, (1, 1, 1), 2) - # t1033 = prims.slice_prim(t1032, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1033: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1034 = prims.slice_prim(t1032, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1034: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1035 = prims.slice_prim(t1032, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1035: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1036 = ltorch.reshape(t1033, 1, -1, 512, 128) # t1036: "cuda:0 bf16[1, 32, 512, 128]" - # t1036 = prims.reshape(t1033, (1, 32, 512, 128)) # t1036: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1037 = ltorch.reshape(t1034, 1, -1, 512, 128) # t1037: "cuda:0 bf16[1, 32, 512, 128]" - # t1037 = prims.reshape(t1034, (1, 32, 512, 128)) # t1037: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1038 = ltorch.reshape(t1035, 1, -1, 512, 128) # t1038: "cuda:0 bf16[1, 32, 512, 128]" - # t1038 = prims.reshape(t1035, (1, 32, 512, 128)) # t1038: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1039 = ltorch.getitem(t1036, (..., slice(None, 128, None))) # t1039: "cuda:0 bf16[1, 32, 512, 128]" - # t1039 = prims.slice_prim(t1036, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1039: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1040 = ltorch.getitem(t1039, (..., slice(None, 64, None))) # t1040: "cuda:0 bf16[1, 32, 512, 64]" - # t1040 = prims.slice_prim(t1039, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1040: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1041 = ltorch.getitem(t1039, (..., slice(64, None, None))) # t1041: "cuda:0 bf16[1, 32, 512, 64]" - # t1041 = prims.slice_prim(t1039, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1041: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1044 = ltorch.neg(t1041) # t1044: "cuda:0 bf16[1, 32, 512, 64]" - # t1042 = prims.convert_element_type(t1041, dtypes.float32) # t1042: "cuda:0 f32[1, 32, 512, 64]" - # t1043 = prims.neg(t1042) # t1043: "cuda:0 f32[1, 32, 512, 64]" - # t1044 = prims.convert_element_type(t1043, dtypes.bfloat16) # t1044: "cuda:0 bf16[1, 32, 512, 64]" - t1045 = ltorch.cat((t1044, t1040), -1) # t1045: "cuda:0 bf16[1, 32, 512, 128]" - # t1045 = prims.cat((t1044, t1040), -1) # t1045: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1048 = ltorch.mul(t1039, cos) # t1048: "cuda:0 f32[1, 32, 512, 128]" - # t1046 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1046: "cuda:0 f32[1, 32, 512, 128]" - # t1047 = prims.convert_element_type(t1039, dtypes.float32) # t1047: "cuda:0 f32[1, 32, 512, 128]" - # t1048 = prims.mul(t1047, t1046) # t1048: "cuda:0 f32[1, 32, 512, 128]" - t1051 = ltorch.mul(t1045, sin) # t1051: "cuda:0 f32[1, 32, 512, 128]" - # t1049 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1049: "cuda:0 f32[1, 32, 512, 128]" - # t1050 = prims.convert_element_type(t1045, dtypes.float32) # t1050: "cuda:0 f32[1, 32, 512, 128]" - # t1051 = prims.mul(t1050, t1049) # t1051: "cuda:0 f32[1, 32, 512, 128]" - t1052 = ltorch.add(t1048, t1051, alpha=None) # t1052: "cuda:0 f32[1, 32, 512, 128]" - # t1052 = prims.add(t1048, t1051) # t1052: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1053 = ltorch.to(t1052, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1053: "cuda:0 bf16[1, 32, 512, 128]" - # t1053 = prims.convert_element_type(t1052, dtypes.bfloat16) # t1053: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1054 = ltorch.getitem(t1037, (..., slice(None, 128, None))) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - # t1054 = prims.slice_prim(t1037, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1055 = ltorch.getitem(t1054, (..., slice(None, 64, None))) # t1055: "cuda:0 bf16[1, 32, 512, 64]" - # t1055 = prims.slice_prim(t1054, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1055: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1056 = ltorch.getitem(t1054, (..., slice(64, None, None))) # t1056: "cuda:0 bf16[1, 32, 512, 64]" - # t1056 = prims.slice_prim(t1054, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1056: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1059 = ltorch.neg(t1056) # t1059: "cuda:0 bf16[1, 32, 512, 64]" - # t1057 = prims.convert_element_type(t1056, dtypes.float32) # t1057: "cuda:0 f32[1, 32, 512, 64]" - # t1058 = prims.neg(t1057) # t1058: "cuda:0 f32[1, 32, 512, 64]" - # t1059 = prims.convert_element_type(t1058, dtypes.bfloat16) # t1059: "cuda:0 bf16[1, 32, 512, 64]" - t1060 = ltorch.cat((t1059, t1055), -1) # t1060: "cuda:0 bf16[1, 32, 512, 128]" - # t1060 = prims.cat((t1059, t1055), -1) # t1060: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1063 = ltorch.mul(t1054, cos) # t1063: "cuda:0 f32[1, 32, 512, 128]" - # t1061 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1061: "cuda:0 f32[1, 32, 512, 128]" - # t1062 = prims.convert_element_type(t1054, dtypes.float32) # t1062: "cuda:0 f32[1, 32, 512, 128]" - # t1063 = prims.mul(t1062, t1061) # t1063: "cuda:0 f32[1, 32, 512, 128]" - t1066 = ltorch.mul(t1060, sin) # t1066: "cuda:0 f32[1, 32, 512, 128]" - # t1064 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1064: "cuda:0 f32[1, 32, 512, 128]" - # t1065 = prims.convert_element_type(t1060, dtypes.float32) # t1065: "cuda:0 f32[1, 32, 512, 128]" - # t1066 = prims.mul(t1065, t1064) # t1066: "cuda:0 f32[1, 32, 512, 128]" - t1067 = ltorch.add(t1063, t1066, alpha=None) # t1067: "cuda:0 f32[1, 32, 512, 128]" - # t1067 = prims.add(t1063, t1066) # t1067: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1068 = ltorch.to(t1067, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1068: "cuda:0 bf16[1, 32, 512, 128]" - # t1068 = prims.convert_element_type(t1067, dtypes.bfloat16) # t1068: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1069 = ltorch.getitem(t1036, (..., slice(128, None, None))) # t1069: "cuda:0 bf16[1, 32, 512, 0]" - # t1069 = prims.slice_prim(t1036, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1069: "cuda:0 bf16[1, 32, 512, 0]" - t1070 = ltorch.cat((t1053, t1069), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - # t1070 = prims.cat((t1053, t1069), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1071 = ltorch.getitem(t1037, (..., slice(128, None, None))) # t1071: "cuda:0 bf16[1, 32, 512, 0]" - # t1071 = prims.slice_prim(t1037, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1071: "cuda:0 bf16[1, 32, 512, 0]" - t1072 = ltorch.cat((t1068, t1071), -1) # t1072: "cuda:0 bf16[1, 32, 512, 128]" - # t1072 = prims.cat((t1068, t1071), -1) # t1072: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1102 = ltorch.scaled_dot_product_attention(t1070, t1072, t1038, None, 0.0, True, scale=0.08838834764831843) # t1102: "cuda:0 bf16[1, 32, 512, 128]" - # t1075 = ltorch.mul(t1070, 0.29730177875068026) # t1075: "cuda:0 bf16[1, 32, 512, 128]" - # t1073 = prims.convert_element_type(t1070, dtypes.float32) # t1073: "cuda:0 f32[1, 32, 512, 128]" - # t1074 = prims.mul(t1073, 0.29730177875068026) # t1074: "cuda:0 f32[1, 32, 512, 128]" - # t1075 = prims.convert_element_type(t1074, dtypes.bfloat16) # t1075: "cuda:0 bf16[1, 32, 512, 128]" - # t1076 = ltorch.transpose(t1072, -2, -1) # t1076: "cuda:0 bf16[1, 32, 128, 512]" - # t1076 = prims.transpose(t1072, (0, 1, 3, 2)) # t1076: "cuda:0 bf16[1, 32, 128, 512]" - # t1079 = ltorch.mul(t1076, 0.29730177875068026) # t1079: "cuda:0 bf16[1, 32, 128, 512]" - # t1077 = prims.convert_element_type(t1076, dtypes.float32) # t1077: "cuda:0 f32[1, 32, 128, 512]" - # t1078 = prims.mul(t1077, 0.29730177875068026) # t1078: "cuda:0 f32[1, 32, 128, 512]" - # t1079 = prims.convert_element_type(t1078, dtypes.bfloat16) # t1079: "cuda:0 bf16[1, 32, 128, 512]" - # t1080 = ltorch.matmul(t1075, t1079) # t1080: "cuda:0 bf16[1, 32, 512, 512]" - # t1080 = prims.matmul(t1075, t1079) # t1080: "cuda:0 bf16[1, 32, 512, 512]" - # t1090 = ltorch.tril(t1080, 0, fill_value=-float('inf')) # t1090: "cuda:0 bf16[1, 32, 512, 512]" - # t1081 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1081: "cuda:0 i64[512]" - # t1081 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1081: "cuda:0 i64[512]" - # t1082 = ltorch.unsqueeze(t1081, -1) # t1082: "cuda:0 i64[512, 1]" - # t1082 = prims.broadcast_in_dim(t1081, [512, 1], [0]) # t1082: "cuda:0 i64[512, 1]" - # t1083 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1083: "cuda:0 i64[512]" - # t1083 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1083: "cuda:0 i64[512]" - # t1084 = ltorch.unsqueeze(t1083, -2) # t1084: "cuda:0 i64[1, 512]" - # t1084 = prims.broadcast_in_dim(t1083, [1, 512], [1]) # t1084: "cuda:0 i64[1, 512]" - # t1085 = ltorch.add(t1082, 0, alpha=None) # t1085: "cuda:0 i64[512, 1]" - # t1085 = prims.add(t1082, 0) # t1085: "cuda:0 i64[512, 1]" - # t1088 = ltorch.ge(t1085, t1084) # t1088: "cuda:0 b8[512, 512]" - # t1086 = prims.broadcast_in_dim(t1085, (512, 512), (0, 1)) # t1086: "cuda:0 i64[512, 512]" - # t1087 = prims.broadcast_in_dim(t1084, (512, 512), (0, 1)) # t1087: "cuda:0 i64[512, 512]" - # t1088 = prims.ge(t1086, t1087) # t1088: "cuda:0 b8[512, 512]" - # t1090 = ltorch.where(t1088, t1080, -float('inf')) # t1090: "cuda:0 bf16[1, 32, 512, 512]" - # t1089 = prims.broadcast_in_dim(t1088, (1, 32, 512, 512), (2, 3)) # t1089: "cuda:0 b8[1, 32, 512, 512]" - # t1090 = prims.where(t1089, t1080, -float('inf')) # t1090: "cuda:0 bf16[1, 32, 512, 512]" - # t1101 = ltorch._softmax(t1090, -1, dtype=None) # t1101: "cuda:0 bf16[1, 32, 512, 512]" - # t1091 = ltorch.to(t1090, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1091: "cuda:0 f32[1, 32, 512, 512]" - # t1091 = prims.convert_element_type(t1090, dtypes.float32) # t1091: "cuda:0 f32[1, 32, 512, 512]" - # t1093 = ltorch.amax(t1091, -1, True) # t1093: "cuda:0 f32[1, 32, 512, 1]" - # t1092 = prims.amax(t1091, (3,)) # t1092: "cuda:0 f32[1, 32, 512]" - # t1093 = prims.broadcast_in_dim(t1092, [1, 32, 512, 1], [0, 1, 2]) # t1093: "cuda:0 f32[1, 32, 512, 1]" - # t1095 = ltorch.sub(t1091, t1093, alpha=None) # t1095: "cuda:0 f32[1, 32, 512, 512]" - # t1094 = prims.broadcast_in_dim(t1093, (1, 32, 512, 512), (0, 1, 2, 3)) # t1094: "cuda:0 f32[1, 32, 512, 512]" - # t1095 = prims.sub(t1091, t1094) # t1095: "cuda:0 f32[1, 32, 512, 512]" - # t1096 = ltorch.exp(t1095) # t1096: "cuda:0 f32[1, 32, 512, 512]" - # t1096 = prims.exp(t1095) # t1096: "cuda:0 f32[1, 32, 512, 512]" - # t1098 = ltorch.sum(t1096, -1, True, dtype=None) # t1098: "cuda:0 f32[1, 32, 512, 1]" - # t1097 = prims.sum(t1096, (3,)) # t1097: "cuda:0 f32[1, 32, 512]" - # t1098 = prims.broadcast_in_dim(t1097, [1, 32, 512, 1], [0, 1, 2]) # t1098: "cuda:0 f32[1, 32, 512, 1]" - # t1100 = ltorch.true_divide(t1096, t1098) # t1100: "cuda:0 f32[1, 32, 512, 512]" - # t1099 = prims.broadcast_in_dim(t1098, (1, 32, 512, 512), (0, 1, 2, 3)) # t1099: "cuda:0 f32[1, 32, 512, 512]" - # t1100 = prims.div(t1096, t1099) # t1100: "cuda:0 f32[1, 32, 512, 512]" - # t1101 = ltorch.to(t1100, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1101: "cuda:0 bf16[1, 32, 512, 512]" - # t1101 = prims.convert_element_type(t1100, dtypes.bfloat16) # t1101: "cuda:0 bf16[1, 32, 512, 512]" - # t1102 = ltorch.matmul(t1101, t1038) # t1102: "cuda:0 bf16[1, 32, 512, 128]" - # t1102 = prims.matmul(t1101, t1038) # t1102: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1103 = ltorch.transpose(t1102, 1, 2) # t1103: "cuda:0 bf16[1, 512, 32, 128]" - # t1103 = prims.transpose(t1102, (0, 2, 1, 3)) # t1103: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1104 = ltorch.reshape(t1103, 1, 512, 4096) # t1104: "cuda:0 bf16[1, 512, 4096]" - # t1104 = prims.reshape(t1103, (1, 512, 4096)) # t1104: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1108 = ltorch.linear(t1104, t_transformer_h_6_attn_proj_weight, None) # t1108: "cuda:0 bf16[1, 512, 4096]" - # t1108 = prims.linear(t1104, t_transformer_h_6_attn_proj_weight, None) # t1108: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1112 = ltorch.add(t1108, t1002, alpha=None) # t1112: "cuda:0 bf16[1, 512, 4096]" - # t1109 = prims.convert_element_type(t1108, dtypes.float32) # t1109: "cuda:0 f32[1, 512, 4096]" - # t1110 = prims.convert_element_type(t1002, dtypes.float32) # t1110: "cuda:0 f32[1, 512, 4096]" - # t1111 = prims.add(t1109, t1110) # t1111: "cuda:0 f32[1, 512, 4096]" - # t1112 = prims.convert_element_type(t1111, dtypes.bfloat16) # t1112: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1113 = prims.convert_element_type(t1112, dtypes.float32) # t1113: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1114 = ltorch.mul(t1113, t1113) # t1114: "cuda:0 f32[1, 512, 4096]" - # t1114 = prims.mul(t1113, t1113) # t1114: "cuda:0 f32[1, 512, 4096]" - t1118 = ltorch.mean(t1114, -1, True, dtype=None) # t1118: "cuda:0 f32[1, 512, 1]" - # t1116 = prims.sum(t1114, (2,)) # t1116: "cuda:0 f32[1, 512]" - # t1117 = prims.broadcast_in_dim(t1116, [1, 512, 1], [0, 1]) # t1117: "cuda:0 f32[1, 512, 1]" - # t1118 = ltorch.true_divide(t1117, 4096) # t1118: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1118 = prims.div(t1117, 4096.0) # t1118: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1120 = ltorch.add(t1118, 1e-05, alpha=None) # t1120: "cuda:0 f32[1, 512, 1]" - # t1120 = prims.add(t1118, 1e-05) # t1120: "cuda:0 f32[1, 512, 1]" - t1121 = ltorch.rsqrt(t1120) # t1121: "cuda:0 f32[1, 512, 1]" - # t1121 = prims.rsqrt(t1120) # t1121: "cuda:0 f32[1, 512, 1]" - t1123 = ltorch.mul(t1113, t1121) # t1123: "cuda:0 f32[1, 512, 4096]" - # t1122 = prims.broadcast_in_dim(t1121, (1, 512, 4096), (0, 1, 2)) # t1122: "cuda:0 f32[1, 512, 4096]" - # t1123 = prims.mul(t1113, t1122) # t1123: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1124 = ltorch.to(t1123, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1124: "cuda:0 bf16[1, 512, 4096]" - # t1124 = prims.convert_element_type(t1123, dtypes.bfloat16) # t1124: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1134 = ltorch.mul(t1124, t_transformer_h_6_norm_2_weight) # t1134: "cuda:0 bf16[1, 512, 4096]" - # t1130 = prims.broadcast_in_dim(t_transformer_h_6_norm_2_weight, (1, 512, 4096), (2,)) # t1130: "cuda:0 bf16[1, 512, 4096]" - # t1131 = prims.convert_element_type(t1124, dtypes.float32) # t1131: "cuda:0 f32[1, 512, 4096]" - # t1132 = prims.convert_element_type(t1130, dtypes.float32) # t1132: "cuda:0 f32[1, 512, 4096]" - # t1133 = prims.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 4096]" - # t1134 = prims.convert_element_type(t1133, dtypes.bfloat16) # t1134: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1139 = ltorch.linear(t1134, t_transformer_h_6_mlp_fc_1_weight, None) # t1139: "cuda:0 bf16[1, 512, 11008]" - # t1139 = prims.linear(t1134, t_transformer_h_6_mlp_fc_1_weight, None) # t1139: "cuda:0 bf16[1, 512, 11008]" - t1143 = ltorch.linear(t1134, t_transformer_h_6_mlp_fc_2_weight, None) # t1143: "cuda:0 bf16[1, 512, 11008]" - # t1143 = prims.linear(t1134, t_transformer_h_6_mlp_fc_2_weight, None) # t1143: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1153 = ltorch.silu(t1139, False) # t1153: "cuda:0 bf16[1, 512, 11008]" - # t1144 = prims.convert_element_type(t1139, dtypes.float32) # t1144: "cuda:0 f32[1, 512, 11008]" - # t1145 = prims.neg(t1144) # t1145: "cuda:0 f32[1, 512, 11008]" - # t1146 = prims.exp(t1145) # t1146: "cuda:0 f32[1, 512, 11008]" - # t1147 = prims.add(1.0, t1146) # t1147: "cuda:0 f32[1, 512, 11008]" - # t1148 = prims.reciprocal(t1147) # t1148: "cuda:0 f32[1, 512, 11008]" - # t1149 = prims.convert_element_type(t1148, dtypes.bfloat16) # t1149: "cuda:0 bf16[1, 512, 11008]" - # t1150 = prims.convert_element_type(t1139, dtypes.float32) # t1150: "cuda:0 f32[1, 512, 11008]" - # t1151 = prims.convert_element_type(t1149, dtypes.float32) # t1151: "cuda:0 f32[1, 512, 11008]" - # t1152 = prims.mul(t1150, t1151) # t1152: "cuda:0 f32[1, 512, 11008]" - # t1153 = prims.convert_element_type(t1152, dtypes.bfloat16) # t1153: "cuda:0 bf16[1, 512, 11008]" - t1157 = ltorch.mul(t1153, t1143) # t1157: "cuda:0 bf16[1, 512, 11008]" - # t1154 = prims.convert_element_type(t1153, dtypes.float32) # t1154: "cuda:0 f32[1, 512, 11008]" - # t1155 = prims.convert_element_type(t1143, dtypes.float32) # t1155: "cuda:0 f32[1, 512, 11008]" - # t1156 = prims.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 11008]" - # t1157 = prims.convert_element_type(t1156, dtypes.bfloat16) # t1157: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1161 = ltorch.linear(t1157, t_transformer_h_6_mlp_proj_weight, None) # t1161: "cuda:0 bf16[1, 512, 4096]" - # t1161 = prims.linear(t1157, t_transformer_h_6_mlp_proj_weight, None) # t1161: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1165 = ltorch.add(t1161, t1112, alpha=None) # t1165: "cuda:0 bf16[1, 512, 4096]" - # t1162 = prims.convert_element_type(t1161, dtypes.float32) # t1162: "cuda:0 f32[1, 512, 4096]" - # t1163 = prims.convert_element_type(t1112, dtypes.float32) # t1163: "cuda:0 f32[1, 512, 4096]" - # t1164 = prims.add(t1162, t1163) # t1164: "cuda:0 f32[1, 512, 4096]" - # t1165 = prims.convert_element_type(t1164, dtypes.bfloat16) # t1165: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1167 = prims.convert_element_type(t1165, dtypes.float32) # t1167: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1168 = ltorch.mul(t1167, t1167) # t1168: "cuda:0 f32[1, 512, 4096]" - # t1168 = prims.mul(t1167, t1167) # t1168: "cuda:0 f32[1, 512, 4096]" - t1172 = ltorch.mean(t1168, -1, True, dtype=None) # t1172: "cuda:0 f32[1, 512, 1]" - # t1170 = prims.sum(t1168, (2,)) # t1170: "cuda:0 f32[1, 512]" - # t1171 = prims.broadcast_in_dim(t1170, [1, 512, 1], [0, 1]) # t1171: "cuda:0 f32[1, 512, 1]" - # t1172 = ltorch.true_divide(t1171, 4096) # t1172: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1172 = prims.div(t1171, 4096.0) # t1172: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1174 = ltorch.add(t1172, 1e-05, alpha=None) # t1174: "cuda:0 f32[1, 512, 1]" - # t1174 = prims.add(t1172, 1e-05) # t1174: "cuda:0 f32[1, 512, 1]" - t1175 = ltorch.rsqrt(t1174) # t1175: "cuda:0 f32[1, 512, 1]" - # t1175 = prims.rsqrt(t1174) # t1175: "cuda:0 f32[1, 512, 1]" - t1177 = ltorch.mul(t1167, t1175) # t1177: "cuda:0 f32[1, 512, 4096]" - # t1176 = prims.broadcast_in_dim(t1175, (1, 512, 4096), (0, 1, 2)) # t1176: "cuda:0 f32[1, 512, 4096]" - # t1177 = prims.mul(t1167, t1176) # t1177: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1178 = ltorch.to(t1177, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1178: "cuda:0 bf16[1, 512, 4096]" - # t1178 = prims.convert_element_type(t1177, dtypes.bfloat16) # t1178: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1188 = ltorch.mul(t1178, t_transformer_h_7_norm_1_weight) # t1188: "cuda:0 bf16[1, 512, 4096]" - # t1184 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_weight, (1, 512, 4096), (2,)) # t1184: "cuda:0 bf16[1, 512, 4096]" - # t1185 = prims.convert_element_type(t1178, dtypes.float32) # t1185: "cuda:0 f32[1, 512, 4096]" - # t1186 = prims.convert_element_type(t1184, dtypes.float32) # t1186: "cuda:0 f32[1, 512, 4096]" - # t1187 = prims.mul(t1185, t1186) # t1187: "cuda:0 f32[1, 512, 4096]" - # t1188 = prims.convert_element_type(t1187, dtypes.bfloat16) # t1188: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1193 = ltorch.linear(t1188, t_transformer_h_7_attn_attn_weight, None) # t1193: "cuda:0 bf16[1, 512, 12288]" - # t1193 = prims.linear(t1188, t_transformer_h_7_attn_attn_weight, None) # t1193: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1194 = ltorch.view(t1193, 1, 512, 32, 3, 128) # t1194: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1194 = ltorch.reshape(t1193, (1, 512, 32, 3, 128)) # t1194: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1194 = prims.reshape(t1193, (1, 512, 32, 3, 128)) # t1194: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1195 = ltorch.permute(t1194, 0, 2, 3, 1, 4) # t1195: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1195 = prims.transpose(t1194, (0, 2, 3, 1, 4)) # t1195: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1196, t1197, t1198) = ltorch.split(t1195, (1, 1, 1), 2) - # t1196 = prims.slice_prim(t1195, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1196: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1197 = prims.slice_prim(t1195, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1197: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1198 = prims.slice_prim(t1195, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1198: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1199 = ltorch.reshape(t1196, 1, -1, 512, 128) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - # t1199 = prims.reshape(t1196, (1, 32, 512, 128)) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1200 = ltorch.reshape(t1197, 1, -1, 512, 128) # t1200: "cuda:0 bf16[1, 32, 512, 128]" - # t1200 = prims.reshape(t1197, (1, 32, 512, 128)) # t1200: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1201 = ltorch.reshape(t1198, 1, -1, 512, 128) # t1201: "cuda:0 bf16[1, 32, 512, 128]" - # t1201 = prims.reshape(t1198, (1, 32, 512, 128)) # t1201: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1202 = ltorch.getitem(t1199, (..., slice(None, 128, None))) # t1202: "cuda:0 bf16[1, 32, 512, 128]" - # t1202 = prims.slice_prim(t1199, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1202: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1203 = ltorch.getitem(t1202, (..., slice(None, 64, None))) # t1203: "cuda:0 bf16[1, 32, 512, 64]" - # t1203 = prims.slice_prim(t1202, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1203: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1204 = ltorch.getitem(t1202, (..., slice(64, None, None))) # t1204: "cuda:0 bf16[1, 32, 512, 64]" - # t1204 = prims.slice_prim(t1202, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1204: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1207 = ltorch.neg(t1204) # t1207: "cuda:0 bf16[1, 32, 512, 64]" - # t1205 = prims.convert_element_type(t1204, dtypes.float32) # t1205: "cuda:0 f32[1, 32, 512, 64]" - # t1206 = prims.neg(t1205) # t1206: "cuda:0 f32[1, 32, 512, 64]" - # t1207 = prims.convert_element_type(t1206, dtypes.bfloat16) # t1207: "cuda:0 bf16[1, 32, 512, 64]" - t1208 = ltorch.cat((t1207, t1203), -1) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - # t1208 = prims.cat((t1207, t1203), -1) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1211 = ltorch.mul(t1202, cos) # t1211: "cuda:0 f32[1, 32, 512, 128]" - # t1209 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1209: "cuda:0 f32[1, 32, 512, 128]" - # t1210 = prims.convert_element_type(t1202, dtypes.float32) # t1210: "cuda:0 f32[1, 32, 512, 128]" - # t1211 = prims.mul(t1210, t1209) # t1211: "cuda:0 f32[1, 32, 512, 128]" - t1214 = ltorch.mul(t1208, sin) # t1214: "cuda:0 f32[1, 32, 512, 128]" - # t1212 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1212: "cuda:0 f32[1, 32, 512, 128]" - # t1213 = prims.convert_element_type(t1208, dtypes.float32) # t1213: "cuda:0 f32[1, 32, 512, 128]" - # t1214 = prims.mul(t1213, t1212) # t1214: "cuda:0 f32[1, 32, 512, 128]" - t1215 = ltorch.add(t1211, t1214, alpha=None) # t1215: "cuda:0 f32[1, 32, 512, 128]" - # t1215 = prims.add(t1211, t1214) # t1215: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1216 = ltorch.to(t1215, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1216: "cuda:0 bf16[1, 32, 512, 128]" - # t1216 = prims.convert_element_type(t1215, dtypes.bfloat16) # t1216: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1217 = ltorch.getitem(t1200, (..., slice(None, 128, None))) # t1217: "cuda:0 bf16[1, 32, 512, 128]" - # t1217 = prims.slice_prim(t1200, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1217: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1218 = ltorch.getitem(t1217, (..., slice(None, 64, None))) # t1218: "cuda:0 bf16[1, 32, 512, 64]" - # t1218 = prims.slice_prim(t1217, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1218: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1219 = ltorch.getitem(t1217, (..., slice(64, None, None))) # t1219: "cuda:0 bf16[1, 32, 512, 64]" - # t1219 = prims.slice_prim(t1217, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1219: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1222 = ltorch.neg(t1219) # t1222: "cuda:0 bf16[1, 32, 512, 64]" - # t1220 = prims.convert_element_type(t1219, dtypes.float32) # t1220: "cuda:0 f32[1, 32, 512, 64]" - # t1221 = prims.neg(t1220) # t1221: "cuda:0 f32[1, 32, 512, 64]" - # t1222 = prims.convert_element_type(t1221, dtypes.bfloat16) # t1222: "cuda:0 bf16[1, 32, 512, 64]" - t1223 = ltorch.cat((t1222, t1218), -1) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - # t1223 = prims.cat((t1222, t1218), -1) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1226 = ltorch.mul(t1217, cos) # t1226: "cuda:0 f32[1, 32, 512, 128]" - # t1224 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1224: "cuda:0 f32[1, 32, 512, 128]" - # t1225 = prims.convert_element_type(t1217, dtypes.float32) # t1225: "cuda:0 f32[1, 32, 512, 128]" - # t1226 = prims.mul(t1225, t1224) # t1226: "cuda:0 f32[1, 32, 512, 128]" - t1229 = ltorch.mul(t1223, sin) # t1229: "cuda:0 f32[1, 32, 512, 128]" - # t1227 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1227: "cuda:0 f32[1, 32, 512, 128]" - # t1228 = prims.convert_element_type(t1223, dtypes.float32) # t1228: "cuda:0 f32[1, 32, 512, 128]" - # t1229 = prims.mul(t1228, t1227) # t1229: "cuda:0 f32[1, 32, 512, 128]" - t1230 = ltorch.add(t1226, t1229, alpha=None) # t1230: "cuda:0 f32[1, 32, 512, 128]" - # t1230 = prims.add(t1226, t1229) # t1230: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1231 = ltorch.to(t1230, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1231: "cuda:0 bf16[1, 32, 512, 128]" - # t1231 = prims.convert_element_type(t1230, dtypes.bfloat16) # t1231: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1232 = ltorch.getitem(t1199, (..., slice(128, None, None))) # t1232: "cuda:0 bf16[1, 32, 512, 0]" - # t1232 = prims.slice_prim(t1199, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1232: "cuda:0 bf16[1, 32, 512, 0]" - t1233 = ltorch.cat((t1216, t1232), -1) # t1233: "cuda:0 bf16[1, 32, 512, 128]" - # t1233 = prims.cat((t1216, t1232), -1) # t1233: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1234 = ltorch.getitem(t1200, (..., slice(128, None, None))) # t1234: "cuda:0 bf16[1, 32, 512, 0]" - # t1234 = prims.slice_prim(t1200, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1234: "cuda:0 bf16[1, 32, 512, 0]" - t1235 = ltorch.cat((t1231, t1234), -1) # t1235: "cuda:0 bf16[1, 32, 512, 128]" - # t1235 = prims.cat((t1231, t1234), -1) # t1235: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1265 = ltorch.scaled_dot_product_attention(t1233, t1235, t1201, None, 0.0, True, scale=0.08838834764831843) # t1265: "cuda:0 bf16[1, 32, 512, 128]" - # t1238 = ltorch.mul(t1233, 0.29730177875068026) # t1238: "cuda:0 bf16[1, 32, 512, 128]" - # t1236 = prims.convert_element_type(t1233, dtypes.float32) # t1236: "cuda:0 f32[1, 32, 512, 128]" - # t1237 = prims.mul(t1236, 0.29730177875068026) # t1237: "cuda:0 f32[1, 32, 512, 128]" - # t1238 = prims.convert_element_type(t1237, dtypes.bfloat16) # t1238: "cuda:0 bf16[1, 32, 512, 128]" - # t1239 = ltorch.transpose(t1235, -2, -1) # t1239: "cuda:0 bf16[1, 32, 128, 512]" - # t1239 = prims.transpose(t1235, (0, 1, 3, 2)) # t1239: "cuda:0 bf16[1, 32, 128, 512]" - # t1242 = ltorch.mul(t1239, 0.29730177875068026) # t1242: "cuda:0 bf16[1, 32, 128, 512]" - # t1240 = prims.convert_element_type(t1239, dtypes.float32) # t1240: "cuda:0 f32[1, 32, 128, 512]" - # t1241 = prims.mul(t1240, 0.29730177875068026) # t1241: "cuda:0 f32[1, 32, 128, 512]" - # t1242 = prims.convert_element_type(t1241, dtypes.bfloat16) # t1242: "cuda:0 bf16[1, 32, 128, 512]" - # t1243 = ltorch.matmul(t1238, t1242) # t1243: "cuda:0 bf16[1, 32, 512, 512]" - # t1243 = prims.matmul(t1238, t1242) # t1243: "cuda:0 bf16[1, 32, 512, 512]" - # t1253 = ltorch.tril(t1243, 0, fill_value=-float('inf')) # t1253: "cuda:0 bf16[1, 32, 512, 512]" - # t1244 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1244: "cuda:0 i64[512]" - # t1244 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1244: "cuda:0 i64[512]" - # t1245 = ltorch.unsqueeze(t1244, -1) # t1245: "cuda:0 i64[512, 1]" - # t1245 = prims.broadcast_in_dim(t1244, [512, 1], [0]) # t1245: "cuda:0 i64[512, 1]" - # t1246 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1246: "cuda:0 i64[512]" - # t1246 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1246: "cuda:0 i64[512]" - # t1247 = ltorch.unsqueeze(t1246, -2) # t1247: "cuda:0 i64[1, 512]" - # t1247 = prims.broadcast_in_dim(t1246, [1, 512], [1]) # t1247: "cuda:0 i64[1, 512]" - # t1248 = ltorch.add(t1245, 0, alpha=None) # t1248: "cuda:0 i64[512, 1]" - # t1248 = prims.add(t1245, 0) # t1248: "cuda:0 i64[512, 1]" - # t1251 = ltorch.ge(t1248, t1247) # t1251: "cuda:0 b8[512, 512]" - # t1249 = prims.broadcast_in_dim(t1248, (512, 512), (0, 1)) # t1249: "cuda:0 i64[512, 512]" - # t1250 = prims.broadcast_in_dim(t1247, (512, 512), (0, 1)) # t1250: "cuda:0 i64[512, 512]" - # t1251 = prims.ge(t1249, t1250) # t1251: "cuda:0 b8[512, 512]" - # t1253 = ltorch.where(t1251, t1243, -float('inf')) # t1253: "cuda:0 bf16[1, 32, 512, 512]" - # t1252 = prims.broadcast_in_dim(t1251, (1, 32, 512, 512), (2, 3)) # t1252: "cuda:0 b8[1, 32, 512, 512]" - # t1253 = prims.where(t1252, t1243, -float('inf')) # t1253: "cuda:0 bf16[1, 32, 512, 512]" - # t1264 = ltorch._softmax(t1253, -1, dtype=None) # t1264: "cuda:0 bf16[1, 32, 512, 512]" - # t1254 = ltorch.to(t1253, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1254: "cuda:0 f32[1, 32, 512, 512]" - # t1254 = prims.convert_element_type(t1253, dtypes.float32) # t1254: "cuda:0 f32[1, 32, 512, 512]" - # t1256 = ltorch.amax(t1254, -1, True) # t1256: "cuda:0 f32[1, 32, 512, 1]" - # t1255 = prims.amax(t1254, (3,)) # t1255: "cuda:0 f32[1, 32, 512]" - # t1256 = prims.broadcast_in_dim(t1255, [1, 32, 512, 1], [0, 1, 2]) # t1256: "cuda:0 f32[1, 32, 512, 1]" - # t1258 = ltorch.sub(t1254, t1256, alpha=None) # t1258: "cuda:0 f32[1, 32, 512, 512]" - # t1257 = prims.broadcast_in_dim(t1256, (1, 32, 512, 512), (0, 1, 2, 3)) # t1257: "cuda:0 f32[1, 32, 512, 512]" - # t1258 = prims.sub(t1254, t1257) # t1258: "cuda:0 f32[1, 32, 512, 512]" - # t1259 = ltorch.exp(t1258) # t1259: "cuda:0 f32[1, 32, 512, 512]" - # t1259 = prims.exp(t1258) # t1259: "cuda:0 f32[1, 32, 512, 512]" - # t1261 = ltorch.sum(t1259, -1, True, dtype=None) # t1261: "cuda:0 f32[1, 32, 512, 1]" - # t1260 = prims.sum(t1259, (3,)) # t1260: "cuda:0 f32[1, 32, 512]" - # t1261 = prims.broadcast_in_dim(t1260, [1, 32, 512, 1], [0, 1, 2]) # t1261: "cuda:0 f32[1, 32, 512, 1]" - # t1263 = ltorch.true_divide(t1259, t1261) # t1263: "cuda:0 f32[1, 32, 512, 512]" - # t1262 = prims.broadcast_in_dim(t1261, (1, 32, 512, 512), (0, 1, 2, 3)) # t1262: "cuda:0 f32[1, 32, 512, 512]" - # t1263 = prims.div(t1259, t1262) # t1263: "cuda:0 f32[1, 32, 512, 512]" - # t1264 = ltorch.to(t1263, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1264: "cuda:0 bf16[1, 32, 512, 512]" - # t1264 = prims.convert_element_type(t1263, dtypes.bfloat16) # t1264: "cuda:0 bf16[1, 32, 512, 512]" - # t1265 = ltorch.matmul(t1264, t1201) # t1265: "cuda:0 bf16[1, 32, 512, 128]" - # t1265 = prims.matmul(t1264, t1201) # t1265: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1266 = ltorch.transpose(t1265, 1, 2) # t1266: "cuda:0 bf16[1, 512, 32, 128]" - # t1266 = prims.transpose(t1265, (0, 2, 1, 3)) # t1266: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1267 = ltorch.reshape(t1266, 1, 512, 4096) # t1267: "cuda:0 bf16[1, 512, 4096]" - # t1267 = prims.reshape(t1266, (1, 512, 4096)) # t1267: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1271 = ltorch.linear(t1267, t_transformer_h_7_attn_proj_weight, None) # t1271: "cuda:0 bf16[1, 512, 4096]" - # t1271 = prims.linear(t1267, t_transformer_h_7_attn_proj_weight, None) # t1271: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1275 = ltorch.add(t1271, t1165, alpha=None) # t1275: "cuda:0 bf16[1, 512, 4096]" - # t1272 = prims.convert_element_type(t1271, dtypes.float32) # t1272: "cuda:0 f32[1, 512, 4096]" - # t1273 = prims.convert_element_type(t1165, dtypes.float32) # t1273: "cuda:0 f32[1, 512, 4096]" - # t1274 = prims.add(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 4096]" - # t1275 = prims.convert_element_type(t1274, dtypes.bfloat16) # t1275: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1276 = prims.convert_element_type(t1275, dtypes.float32) # t1276: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1277 = ltorch.mul(t1276, t1276) # t1277: "cuda:0 f32[1, 512, 4096]" - # t1277 = prims.mul(t1276, t1276) # t1277: "cuda:0 f32[1, 512, 4096]" - t1281 = ltorch.mean(t1277, -1, True, dtype=None) # t1281: "cuda:0 f32[1, 512, 1]" - # t1279 = prims.sum(t1277, (2,)) # t1279: "cuda:0 f32[1, 512]" - # t1280 = prims.broadcast_in_dim(t1279, [1, 512, 1], [0, 1]) # t1280: "cuda:0 f32[1, 512, 1]" - # t1281 = ltorch.true_divide(t1280, 4096) # t1281: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1281 = prims.div(t1280, 4096.0) # t1281: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1283 = ltorch.add(t1281, 1e-05, alpha=None) # t1283: "cuda:0 f32[1, 512, 1]" - # t1283 = prims.add(t1281, 1e-05) # t1283: "cuda:0 f32[1, 512, 1]" - t1284 = ltorch.rsqrt(t1283) # t1284: "cuda:0 f32[1, 512, 1]" - # t1284 = prims.rsqrt(t1283) # t1284: "cuda:0 f32[1, 512, 1]" - t1286 = ltorch.mul(t1276, t1284) # t1286: "cuda:0 f32[1, 512, 4096]" - # t1285 = prims.broadcast_in_dim(t1284, (1, 512, 4096), (0, 1, 2)) # t1285: "cuda:0 f32[1, 512, 4096]" - # t1286 = prims.mul(t1276, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1287 = ltorch.to(t1286, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1287: "cuda:0 bf16[1, 512, 4096]" - # t1287 = prims.convert_element_type(t1286, dtypes.bfloat16) # t1287: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1297 = ltorch.mul(t1287, t_transformer_h_7_norm_2_weight) # t1297: "cuda:0 bf16[1, 512, 4096]" - # t1293 = prims.broadcast_in_dim(t_transformer_h_7_norm_2_weight, (1, 512, 4096), (2,)) # t1293: "cuda:0 bf16[1, 512, 4096]" - # t1294 = prims.convert_element_type(t1287, dtypes.float32) # t1294: "cuda:0 f32[1, 512, 4096]" - # t1295 = prims.convert_element_type(t1293, dtypes.float32) # t1295: "cuda:0 f32[1, 512, 4096]" - # t1296 = prims.mul(t1294, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - # t1297 = prims.convert_element_type(t1296, dtypes.bfloat16) # t1297: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1302 = ltorch.linear(t1297, t_transformer_h_7_mlp_fc_1_weight, None) # t1302: "cuda:0 bf16[1, 512, 11008]" - # t1302 = prims.linear(t1297, t_transformer_h_7_mlp_fc_1_weight, None) # t1302: "cuda:0 bf16[1, 512, 11008]" - t1306 = ltorch.linear(t1297, t_transformer_h_7_mlp_fc_2_weight, None) # t1306: "cuda:0 bf16[1, 512, 11008]" - # t1306 = prims.linear(t1297, t_transformer_h_7_mlp_fc_2_weight, None) # t1306: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1316 = ltorch.silu(t1302, False) # t1316: "cuda:0 bf16[1, 512, 11008]" - # t1307 = prims.convert_element_type(t1302, dtypes.float32) # t1307: "cuda:0 f32[1, 512, 11008]" - # t1308 = prims.neg(t1307) # t1308: "cuda:0 f32[1, 512, 11008]" - # t1309 = prims.exp(t1308) # t1309: "cuda:0 f32[1, 512, 11008]" - # t1310 = prims.add(1.0, t1309) # t1310: "cuda:0 f32[1, 512, 11008]" - # t1311 = prims.reciprocal(t1310) # t1311: "cuda:0 f32[1, 512, 11008]" - # t1312 = prims.convert_element_type(t1311, dtypes.bfloat16) # t1312: "cuda:0 bf16[1, 512, 11008]" - # t1313 = prims.convert_element_type(t1302, dtypes.float32) # t1313: "cuda:0 f32[1, 512, 11008]" - # t1314 = prims.convert_element_type(t1312, dtypes.float32) # t1314: "cuda:0 f32[1, 512, 11008]" - # t1315 = prims.mul(t1313, t1314) # t1315: "cuda:0 f32[1, 512, 11008]" - # t1316 = prims.convert_element_type(t1315, dtypes.bfloat16) # t1316: "cuda:0 bf16[1, 512, 11008]" - t1320 = ltorch.mul(t1316, t1306) # t1320: "cuda:0 bf16[1, 512, 11008]" - # t1317 = prims.convert_element_type(t1316, dtypes.float32) # t1317: "cuda:0 f32[1, 512, 11008]" - # t1318 = prims.convert_element_type(t1306, dtypes.float32) # t1318: "cuda:0 f32[1, 512, 11008]" - # t1319 = prims.mul(t1317, t1318) # t1319: "cuda:0 f32[1, 512, 11008]" - # t1320 = prims.convert_element_type(t1319, dtypes.bfloat16) # t1320: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1324 = ltorch.linear(t1320, t_transformer_h_7_mlp_proj_weight, None) # t1324: "cuda:0 bf16[1, 512, 4096]" - # t1324 = prims.linear(t1320, t_transformer_h_7_mlp_proj_weight, None) # t1324: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1328 = ltorch.add(t1324, t1275, alpha=None) # t1328: "cuda:0 bf16[1, 512, 4096]" - # t1325 = prims.convert_element_type(t1324, dtypes.float32) # t1325: "cuda:0 f32[1, 512, 4096]" - # t1326 = prims.convert_element_type(t1275, dtypes.float32) # t1326: "cuda:0 f32[1, 512, 4096]" - # t1327 = prims.add(t1325, t1326) # t1327: "cuda:0 f32[1, 512, 4096]" - # t1328 = prims.convert_element_type(t1327, dtypes.bfloat16) # t1328: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1330 = prims.convert_element_type(t1328, dtypes.float32) # t1330: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1331 = ltorch.mul(t1330, t1330) # t1331: "cuda:0 f32[1, 512, 4096]" - # t1331 = prims.mul(t1330, t1330) # t1331: "cuda:0 f32[1, 512, 4096]" - t1335 = ltorch.mean(t1331, -1, True, dtype=None) # t1335: "cuda:0 f32[1, 512, 1]" - # t1333 = prims.sum(t1331, (2,)) # t1333: "cuda:0 f32[1, 512]" - # t1334 = prims.broadcast_in_dim(t1333, [1, 512, 1], [0, 1]) # t1334: "cuda:0 f32[1, 512, 1]" - # t1335 = ltorch.true_divide(t1334, 4096) # t1335: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1335 = prims.div(t1334, 4096.0) # t1335: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1337 = ltorch.add(t1335, 1e-05, alpha=None) # t1337: "cuda:0 f32[1, 512, 1]" - # t1337 = prims.add(t1335, 1e-05) # t1337: "cuda:0 f32[1, 512, 1]" - t1338 = ltorch.rsqrt(t1337) # t1338: "cuda:0 f32[1, 512, 1]" - # t1338 = prims.rsqrt(t1337) # t1338: "cuda:0 f32[1, 512, 1]" - t1340 = ltorch.mul(t1330, t1338) # t1340: "cuda:0 f32[1, 512, 4096]" - # t1339 = prims.broadcast_in_dim(t1338, (1, 512, 4096), (0, 1, 2)) # t1339: "cuda:0 f32[1, 512, 4096]" - # t1340 = prims.mul(t1330, t1339) # t1340: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1341 = ltorch.to(t1340, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1341: "cuda:0 bf16[1, 512, 4096]" - # t1341 = prims.convert_element_type(t1340, dtypes.bfloat16) # t1341: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1351 = ltorch.mul(t1341, t_transformer_h_8_norm_1_weight) # t1351: "cuda:0 bf16[1, 512, 4096]" - # t1347 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_weight, (1, 512, 4096), (2,)) # t1347: "cuda:0 bf16[1, 512, 4096]" - # t1348 = prims.convert_element_type(t1341, dtypes.float32) # t1348: "cuda:0 f32[1, 512, 4096]" - # t1349 = prims.convert_element_type(t1347, dtypes.float32) # t1349: "cuda:0 f32[1, 512, 4096]" - # t1350 = prims.mul(t1348, t1349) # t1350: "cuda:0 f32[1, 512, 4096]" - # t1351 = prims.convert_element_type(t1350, dtypes.bfloat16) # t1351: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1356 = ltorch.linear(t1351, t_transformer_h_8_attn_attn_weight, None) # t1356: "cuda:0 bf16[1, 512, 12288]" - # t1356 = prims.linear(t1351, t_transformer_h_8_attn_attn_weight, None) # t1356: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1357 = ltorch.view(t1356, 1, 512, 32, 3, 128) # t1357: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1357 = ltorch.reshape(t1356, (1, 512, 32, 3, 128)) # t1357: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1357 = prims.reshape(t1356, (1, 512, 32, 3, 128)) # t1357: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1358 = ltorch.permute(t1357, 0, 2, 3, 1, 4) # t1358: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1358 = prims.transpose(t1357, (0, 2, 3, 1, 4)) # t1358: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1359, t1360, t1361) = ltorch.split(t1358, (1, 1, 1), 2) - # t1359 = prims.slice_prim(t1358, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1359: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1360 = prims.slice_prim(t1358, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1360: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1361 = prims.slice_prim(t1358, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1361: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1362 = ltorch.reshape(t1359, 1, -1, 512, 128) # t1362: "cuda:0 bf16[1, 32, 512, 128]" - # t1362 = prims.reshape(t1359, (1, 32, 512, 128)) # t1362: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1363 = ltorch.reshape(t1360, 1, -1, 512, 128) # t1363: "cuda:0 bf16[1, 32, 512, 128]" - # t1363 = prims.reshape(t1360, (1, 32, 512, 128)) # t1363: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1364 = ltorch.reshape(t1361, 1, -1, 512, 128) # t1364: "cuda:0 bf16[1, 32, 512, 128]" - # t1364 = prims.reshape(t1361, (1, 32, 512, 128)) # t1364: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1365 = ltorch.getitem(t1362, (..., slice(None, 128, None))) # t1365: "cuda:0 bf16[1, 32, 512, 128]" - # t1365 = prims.slice_prim(t1362, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1365: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1366 = ltorch.getitem(t1365, (..., slice(None, 64, None))) # t1366: "cuda:0 bf16[1, 32, 512, 64]" - # t1366 = prims.slice_prim(t1365, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1366: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1367 = ltorch.getitem(t1365, (..., slice(64, None, None))) # t1367: "cuda:0 bf16[1, 32, 512, 64]" - # t1367 = prims.slice_prim(t1365, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1367: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1370 = ltorch.neg(t1367) # t1370: "cuda:0 bf16[1, 32, 512, 64]" - # t1368 = prims.convert_element_type(t1367, dtypes.float32) # t1368: "cuda:0 f32[1, 32, 512, 64]" - # t1369 = prims.neg(t1368) # t1369: "cuda:0 f32[1, 32, 512, 64]" - # t1370 = prims.convert_element_type(t1369, dtypes.bfloat16) # t1370: "cuda:0 bf16[1, 32, 512, 64]" - t1371 = ltorch.cat((t1370, t1366), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - # t1371 = prims.cat((t1370, t1366), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1374 = ltorch.mul(t1365, cos) # t1374: "cuda:0 f32[1, 32, 512, 128]" - # t1372 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1372: "cuda:0 f32[1, 32, 512, 128]" - # t1373 = prims.convert_element_type(t1365, dtypes.float32) # t1373: "cuda:0 f32[1, 32, 512, 128]" - # t1374 = prims.mul(t1373, t1372) # t1374: "cuda:0 f32[1, 32, 512, 128]" - t1377 = ltorch.mul(t1371, sin) # t1377: "cuda:0 f32[1, 32, 512, 128]" - # t1375 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1375: "cuda:0 f32[1, 32, 512, 128]" - # t1376 = prims.convert_element_type(t1371, dtypes.float32) # t1376: "cuda:0 f32[1, 32, 512, 128]" - # t1377 = prims.mul(t1376, t1375) # t1377: "cuda:0 f32[1, 32, 512, 128]" - t1378 = ltorch.add(t1374, t1377, alpha=None) # t1378: "cuda:0 f32[1, 32, 512, 128]" - # t1378 = prims.add(t1374, t1377) # t1378: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1379 = ltorch.to(t1378, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1379: "cuda:0 bf16[1, 32, 512, 128]" - # t1379 = prims.convert_element_type(t1378, dtypes.bfloat16) # t1379: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1380 = ltorch.getitem(t1363, (..., slice(None, 128, None))) # t1380: "cuda:0 bf16[1, 32, 512, 128]" - # t1380 = prims.slice_prim(t1363, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1380: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1381 = ltorch.getitem(t1380, (..., slice(None, 64, None))) # t1381: "cuda:0 bf16[1, 32, 512, 64]" - # t1381 = prims.slice_prim(t1380, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1381: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1382 = ltorch.getitem(t1380, (..., slice(64, None, None))) # t1382: "cuda:0 bf16[1, 32, 512, 64]" - # t1382 = prims.slice_prim(t1380, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1382: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1385 = ltorch.neg(t1382) # t1385: "cuda:0 bf16[1, 32, 512, 64]" - # t1383 = prims.convert_element_type(t1382, dtypes.float32) # t1383: "cuda:0 f32[1, 32, 512, 64]" - # t1384 = prims.neg(t1383) # t1384: "cuda:0 f32[1, 32, 512, 64]" - # t1385 = prims.convert_element_type(t1384, dtypes.bfloat16) # t1385: "cuda:0 bf16[1, 32, 512, 64]" - t1386 = ltorch.cat((t1385, t1381), -1) # t1386: "cuda:0 bf16[1, 32, 512, 128]" - # t1386 = prims.cat((t1385, t1381), -1) # t1386: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1389 = ltorch.mul(t1380, cos) # t1389: "cuda:0 f32[1, 32, 512, 128]" - # t1387 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1387: "cuda:0 f32[1, 32, 512, 128]" - # t1388 = prims.convert_element_type(t1380, dtypes.float32) # t1388: "cuda:0 f32[1, 32, 512, 128]" - # t1389 = prims.mul(t1388, t1387) # t1389: "cuda:0 f32[1, 32, 512, 128]" - t1392 = ltorch.mul(t1386, sin) # t1392: "cuda:0 f32[1, 32, 512, 128]" - # t1390 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1390: "cuda:0 f32[1, 32, 512, 128]" - # t1391 = prims.convert_element_type(t1386, dtypes.float32) # t1391: "cuda:0 f32[1, 32, 512, 128]" - # t1392 = prims.mul(t1391, t1390) # t1392: "cuda:0 f32[1, 32, 512, 128]" - t1393 = ltorch.add(t1389, t1392, alpha=None) # t1393: "cuda:0 f32[1, 32, 512, 128]" - # t1393 = prims.add(t1389, t1392) # t1393: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1394 = ltorch.to(t1393, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1394: "cuda:0 bf16[1, 32, 512, 128]" - # t1394 = prims.convert_element_type(t1393, dtypes.bfloat16) # t1394: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1395 = ltorch.getitem(t1362, (..., slice(128, None, None))) # t1395: "cuda:0 bf16[1, 32, 512, 0]" - # t1395 = prims.slice_prim(t1362, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1395: "cuda:0 bf16[1, 32, 512, 0]" - t1396 = ltorch.cat((t1379, t1395), -1) # t1396: "cuda:0 bf16[1, 32, 512, 128]" - # t1396 = prims.cat((t1379, t1395), -1) # t1396: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1397 = ltorch.getitem(t1363, (..., slice(128, None, None))) # t1397: "cuda:0 bf16[1, 32, 512, 0]" - # t1397 = prims.slice_prim(t1363, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1397: "cuda:0 bf16[1, 32, 512, 0]" - t1398 = ltorch.cat((t1394, t1397), -1) # t1398: "cuda:0 bf16[1, 32, 512, 128]" - # t1398 = prims.cat((t1394, t1397), -1) # t1398: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1428 = ltorch.scaled_dot_product_attention(t1396, t1398, t1364, None, 0.0, True, scale=0.08838834764831843) # t1428: "cuda:0 bf16[1, 32, 512, 128]" - # t1401 = ltorch.mul(t1396, 0.29730177875068026) # t1401: "cuda:0 bf16[1, 32, 512, 128]" - # t1399 = prims.convert_element_type(t1396, dtypes.float32) # t1399: "cuda:0 f32[1, 32, 512, 128]" - # t1400 = prims.mul(t1399, 0.29730177875068026) # t1400: "cuda:0 f32[1, 32, 512, 128]" - # t1401 = prims.convert_element_type(t1400, dtypes.bfloat16) # t1401: "cuda:0 bf16[1, 32, 512, 128]" - # t1402 = ltorch.transpose(t1398, -2, -1) # t1402: "cuda:0 bf16[1, 32, 128, 512]" - # t1402 = prims.transpose(t1398, (0, 1, 3, 2)) # t1402: "cuda:0 bf16[1, 32, 128, 512]" - # t1405 = ltorch.mul(t1402, 0.29730177875068026) # t1405: "cuda:0 bf16[1, 32, 128, 512]" - # t1403 = prims.convert_element_type(t1402, dtypes.float32) # t1403: "cuda:0 f32[1, 32, 128, 512]" - # t1404 = prims.mul(t1403, 0.29730177875068026) # t1404: "cuda:0 f32[1, 32, 128, 512]" - # t1405 = prims.convert_element_type(t1404, dtypes.bfloat16) # t1405: "cuda:0 bf16[1, 32, 128, 512]" - # t1406 = ltorch.matmul(t1401, t1405) # t1406: "cuda:0 bf16[1, 32, 512, 512]" - # t1406 = prims.matmul(t1401, t1405) # t1406: "cuda:0 bf16[1, 32, 512, 512]" - # t1416 = ltorch.tril(t1406, 0, fill_value=-float('inf')) # t1416: "cuda:0 bf16[1, 32, 512, 512]" - # t1407 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1407: "cuda:0 i64[512]" - # t1407 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1407: "cuda:0 i64[512]" - # t1408 = ltorch.unsqueeze(t1407, -1) # t1408: "cuda:0 i64[512, 1]" - # t1408 = prims.broadcast_in_dim(t1407, [512, 1], [0]) # t1408: "cuda:0 i64[512, 1]" - # t1409 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1409: "cuda:0 i64[512]" - # t1409 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1409: "cuda:0 i64[512]" - # t1410 = ltorch.unsqueeze(t1409, -2) # t1410: "cuda:0 i64[1, 512]" - # t1410 = prims.broadcast_in_dim(t1409, [1, 512], [1]) # t1410: "cuda:0 i64[1, 512]" - # t1411 = ltorch.add(t1408, 0, alpha=None) # t1411: "cuda:0 i64[512, 1]" - # t1411 = prims.add(t1408, 0) # t1411: "cuda:0 i64[512, 1]" - # t1414 = ltorch.ge(t1411, t1410) # t1414: "cuda:0 b8[512, 512]" - # t1412 = prims.broadcast_in_dim(t1411, (512, 512), (0, 1)) # t1412: "cuda:0 i64[512, 512]" - # t1413 = prims.broadcast_in_dim(t1410, (512, 512), (0, 1)) # t1413: "cuda:0 i64[512, 512]" - # t1414 = prims.ge(t1412, t1413) # t1414: "cuda:0 b8[512, 512]" - # t1416 = ltorch.where(t1414, t1406, -float('inf')) # t1416: "cuda:0 bf16[1, 32, 512, 512]" - # t1415 = prims.broadcast_in_dim(t1414, (1, 32, 512, 512), (2, 3)) # t1415: "cuda:0 b8[1, 32, 512, 512]" - # t1416 = prims.where(t1415, t1406, -float('inf')) # t1416: "cuda:0 bf16[1, 32, 512, 512]" - # t1427 = ltorch._softmax(t1416, -1, dtype=None) # t1427: "cuda:0 bf16[1, 32, 512, 512]" - # t1417 = ltorch.to(t1416, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1417: "cuda:0 f32[1, 32, 512, 512]" - # t1417 = prims.convert_element_type(t1416, dtypes.float32) # t1417: "cuda:0 f32[1, 32, 512, 512]" - # t1419 = ltorch.amax(t1417, -1, True) # t1419: "cuda:0 f32[1, 32, 512, 1]" - # t1418 = prims.amax(t1417, (3,)) # t1418: "cuda:0 f32[1, 32, 512]" - # t1419 = prims.broadcast_in_dim(t1418, [1, 32, 512, 1], [0, 1, 2]) # t1419: "cuda:0 f32[1, 32, 512, 1]" - # t1421 = ltorch.sub(t1417, t1419, alpha=None) # t1421: "cuda:0 f32[1, 32, 512, 512]" - # t1420 = prims.broadcast_in_dim(t1419, (1, 32, 512, 512), (0, 1, 2, 3)) # t1420: "cuda:0 f32[1, 32, 512, 512]" - # t1421 = prims.sub(t1417, t1420) # t1421: "cuda:0 f32[1, 32, 512, 512]" - # t1422 = ltorch.exp(t1421) # t1422: "cuda:0 f32[1, 32, 512, 512]" - # t1422 = prims.exp(t1421) # t1422: "cuda:0 f32[1, 32, 512, 512]" - # t1424 = ltorch.sum(t1422, -1, True, dtype=None) # t1424: "cuda:0 f32[1, 32, 512, 1]" - # t1423 = prims.sum(t1422, (3,)) # t1423: "cuda:0 f32[1, 32, 512]" - # t1424 = prims.broadcast_in_dim(t1423, [1, 32, 512, 1], [0, 1, 2]) # t1424: "cuda:0 f32[1, 32, 512, 1]" - # t1426 = ltorch.true_divide(t1422, t1424) # t1426: "cuda:0 f32[1, 32, 512, 512]" - # t1425 = prims.broadcast_in_dim(t1424, (1, 32, 512, 512), (0, 1, 2, 3)) # t1425: "cuda:0 f32[1, 32, 512, 512]" - # t1426 = prims.div(t1422, t1425) # t1426: "cuda:0 f32[1, 32, 512, 512]" - # t1427 = ltorch.to(t1426, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1427: "cuda:0 bf16[1, 32, 512, 512]" - # t1427 = prims.convert_element_type(t1426, dtypes.bfloat16) # t1427: "cuda:0 bf16[1, 32, 512, 512]" - # t1428 = ltorch.matmul(t1427, t1364) # t1428: "cuda:0 bf16[1, 32, 512, 128]" - # t1428 = prims.matmul(t1427, t1364) # t1428: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1429 = ltorch.transpose(t1428, 1, 2) # t1429: "cuda:0 bf16[1, 512, 32, 128]" - # t1429 = prims.transpose(t1428, (0, 2, 1, 3)) # t1429: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1430 = ltorch.reshape(t1429, 1, 512, 4096) # t1430: "cuda:0 bf16[1, 512, 4096]" - # t1430 = prims.reshape(t1429, (1, 512, 4096)) # t1430: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1434 = ltorch.linear(t1430, t_transformer_h_8_attn_proj_weight, None) # t1434: "cuda:0 bf16[1, 512, 4096]" - # t1434 = prims.linear(t1430, t_transformer_h_8_attn_proj_weight, None) # t1434: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1438 = ltorch.add(t1434, t1328, alpha=None) # t1438: "cuda:0 bf16[1, 512, 4096]" - # t1435 = prims.convert_element_type(t1434, dtypes.float32) # t1435: "cuda:0 f32[1, 512, 4096]" - # t1436 = prims.convert_element_type(t1328, dtypes.float32) # t1436: "cuda:0 f32[1, 512, 4096]" - # t1437 = prims.add(t1435, t1436) # t1437: "cuda:0 f32[1, 512, 4096]" - # t1438 = prims.convert_element_type(t1437, dtypes.bfloat16) # t1438: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1439 = prims.convert_element_type(t1438, dtypes.float32) # t1439: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1440 = ltorch.mul(t1439, t1439) # t1440: "cuda:0 f32[1, 512, 4096]" - # t1440 = prims.mul(t1439, t1439) # t1440: "cuda:0 f32[1, 512, 4096]" - t1444 = ltorch.mean(t1440, -1, True, dtype=None) # t1444: "cuda:0 f32[1, 512, 1]" - # t1442 = prims.sum(t1440, (2,)) # t1442: "cuda:0 f32[1, 512]" - # t1443 = prims.broadcast_in_dim(t1442, [1, 512, 1], [0, 1]) # t1443: "cuda:0 f32[1, 512, 1]" - # t1444 = ltorch.true_divide(t1443, 4096) # t1444: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1444 = prims.div(t1443, 4096.0) # t1444: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1446 = ltorch.add(t1444, 1e-05, alpha=None) # t1446: "cuda:0 f32[1, 512, 1]" - # t1446 = prims.add(t1444, 1e-05) # t1446: "cuda:0 f32[1, 512, 1]" - t1447 = ltorch.rsqrt(t1446) # t1447: "cuda:0 f32[1, 512, 1]" - # t1447 = prims.rsqrt(t1446) # t1447: "cuda:0 f32[1, 512, 1]" - t1449 = ltorch.mul(t1439, t1447) # t1449: "cuda:0 f32[1, 512, 4096]" - # t1448 = prims.broadcast_in_dim(t1447, (1, 512, 4096), (0, 1, 2)) # t1448: "cuda:0 f32[1, 512, 4096]" - # t1449 = prims.mul(t1439, t1448) # t1449: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1450 = ltorch.to(t1449, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1450: "cuda:0 bf16[1, 512, 4096]" - # t1450 = prims.convert_element_type(t1449, dtypes.bfloat16) # t1450: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1460 = ltorch.mul(t1450, t_transformer_h_8_norm_2_weight) # t1460: "cuda:0 bf16[1, 512, 4096]" - # t1456 = prims.broadcast_in_dim(t_transformer_h_8_norm_2_weight, (1, 512, 4096), (2,)) # t1456: "cuda:0 bf16[1, 512, 4096]" - # t1457 = prims.convert_element_type(t1450, dtypes.float32) # t1457: "cuda:0 f32[1, 512, 4096]" - # t1458 = prims.convert_element_type(t1456, dtypes.float32) # t1458: "cuda:0 f32[1, 512, 4096]" - # t1459 = prims.mul(t1457, t1458) # t1459: "cuda:0 f32[1, 512, 4096]" - # t1460 = prims.convert_element_type(t1459, dtypes.bfloat16) # t1460: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1465 = ltorch.linear(t1460, t_transformer_h_8_mlp_fc_1_weight, None) # t1465: "cuda:0 bf16[1, 512, 11008]" - # t1465 = prims.linear(t1460, t_transformer_h_8_mlp_fc_1_weight, None) # t1465: "cuda:0 bf16[1, 512, 11008]" - t1469 = ltorch.linear(t1460, t_transformer_h_8_mlp_fc_2_weight, None) # t1469: "cuda:0 bf16[1, 512, 11008]" - # t1469 = prims.linear(t1460, t_transformer_h_8_mlp_fc_2_weight, None) # t1469: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1479 = ltorch.silu(t1465, False) # t1479: "cuda:0 bf16[1, 512, 11008]" - # t1470 = prims.convert_element_type(t1465, dtypes.float32) # t1470: "cuda:0 f32[1, 512, 11008]" - # t1471 = prims.neg(t1470) # t1471: "cuda:0 f32[1, 512, 11008]" - # t1472 = prims.exp(t1471) # t1472: "cuda:0 f32[1, 512, 11008]" - # t1473 = prims.add(1.0, t1472) # t1473: "cuda:0 f32[1, 512, 11008]" - # t1474 = prims.reciprocal(t1473) # t1474: "cuda:0 f32[1, 512, 11008]" - # t1475 = prims.convert_element_type(t1474, dtypes.bfloat16) # t1475: "cuda:0 bf16[1, 512, 11008]" - # t1476 = prims.convert_element_type(t1465, dtypes.float32) # t1476: "cuda:0 f32[1, 512, 11008]" - # t1477 = prims.convert_element_type(t1475, dtypes.float32) # t1477: "cuda:0 f32[1, 512, 11008]" - # t1478 = prims.mul(t1476, t1477) # t1478: "cuda:0 f32[1, 512, 11008]" - # t1479 = prims.convert_element_type(t1478, dtypes.bfloat16) # t1479: "cuda:0 bf16[1, 512, 11008]" - t1483 = ltorch.mul(t1479, t1469) # t1483: "cuda:0 bf16[1, 512, 11008]" - # t1480 = prims.convert_element_type(t1479, dtypes.float32) # t1480: "cuda:0 f32[1, 512, 11008]" - # t1481 = prims.convert_element_type(t1469, dtypes.float32) # t1481: "cuda:0 f32[1, 512, 11008]" - # t1482 = prims.mul(t1480, t1481) # t1482: "cuda:0 f32[1, 512, 11008]" - # t1483 = prims.convert_element_type(t1482, dtypes.bfloat16) # t1483: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1487 = ltorch.linear(t1483, t_transformer_h_8_mlp_proj_weight, None) # t1487: "cuda:0 bf16[1, 512, 4096]" - # t1487 = prims.linear(t1483, t_transformer_h_8_mlp_proj_weight, None) # t1487: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1491 = ltorch.add(t1487, t1438, alpha=None) # t1491: "cuda:0 bf16[1, 512, 4096]" - # t1488 = prims.convert_element_type(t1487, dtypes.float32) # t1488: "cuda:0 f32[1, 512, 4096]" - # t1489 = prims.convert_element_type(t1438, dtypes.float32) # t1489: "cuda:0 f32[1, 512, 4096]" - # t1490 = prims.add(t1488, t1489) # t1490: "cuda:0 f32[1, 512, 4096]" - # t1491 = prims.convert_element_type(t1490, dtypes.bfloat16) # t1491: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1493 = prims.convert_element_type(t1491, dtypes.float32) # t1493: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1494 = ltorch.mul(t1493, t1493) # t1494: "cuda:0 f32[1, 512, 4096]" - # t1494 = prims.mul(t1493, t1493) # t1494: "cuda:0 f32[1, 512, 4096]" - t1498 = ltorch.mean(t1494, -1, True, dtype=None) # t1498: "cuda:0 f32[1, 512, 1]" - # t1496 = prims.sum(t1494, (2,)) # t1496: "cuda:0 f32[1, 512]" - # t1497 = prims.broadcast_in_dim(t1496, [1, 512, 1], [0, 1]) # t1497: "cuda:0 f32[1, 512, 1]" - # t1498 = ltorch.true_divide(t1497, 4096) # t1498: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1498 = prims.div(t1497, 4096.0) # t1498: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1500 = ltorch.add(t1498, 1e-05, alpha=None) # t1500: "cuda:0 f32[1, 512, 1]" - # t1500 = prims.add(t1498, 1e-05) # t1500: "cuda:0 f32[1, 512, 1]" - t1501 = ltorch.rsqrt(t1500) # t1501: "cuda:0 f32[1, 512, 1]" - # t1501 = prims.rsqrt(t1500) # t1501: "cuda:0 f32[1, 512, 1]" - t1503 = ltorch.mul(t1493, t1501) # t1503: "cuda:0 f32[1, 512, 4096]" - # t1502 = prims.broadcast_in_dim(t1501, (1, 512, 4096), (0, 1, 2)) # t1502: "cuda:0 f32[1, 512, 4096]" - # t1503 = prims.mul(t1493, t1502) # t1503: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1504 = ltorch.to(t1503, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1504: "cuda:0 bf16[1, 512, 4096]" - # t1504 = prims.convert_element_type(t1503, dtypes.bfloat16) # t1504: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1514 = ltorch.mul(t1504, t_transformer_h_9_norm_1_weight) # t1514: "cuda:0 bf16[1, 512, 4096]" - # t1510 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_weight, (1, 512, 4096), (2,)) # t1510: "cuda:0 bf16[1, 512, 4096]" - # t1511 = prims.convert_element_type(t1504, dtypes.float32) # t1511: "cuda:0 f32[1, 512, 4096]" - # t1512 = prims.convert_element_type(t1510, dtypes.float32) # t1512: "cuda:0 f32[1, 512, 4096]" - # t1513 = prims.mul(t1511, t1512) # t1513: "cuda:0 f32[1, 512, 4096]" - # t1514 = prims.convert_element_type(t1513, dtypes.bfloat16) # t1514: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1519 = ltorch.linear(t1514, t_transformer_h_9_attn_attn_weight, None) # t1519: "cuda:0 bf16[1, 512, 12288]" - # t1519 = prims.linear(t1514, t_transformer_h_9_attn_attn_weight, None) # t1519: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1520 = ltorch.view(t1519, 1, 512, 32, 3, 128) # t1520: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1520 = ltorch.reshape(t1519, (1, 512, 32, 3, 128)) # t1520: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1520 = prims.reshape(t1519, (1, 512, 32, 3, 128)) # t1520: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1521 = ltorch.permute(t1520, 0, 2, 3, 1, 4) # t1521: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1521 = prims.transpose(t1520, (0, 2, 3, 1, 4)) # t1521: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1522, t1523, t1524) = ltorch.split(t1521, (1, 1, 1), 2) - # t1522 = prims.slice_prim(t1521, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1522: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1523 = prims.slice_prim(t1521, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1523: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1524 = prims.slice_prim(t1521, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1524: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1525 = ltorch.reshape(t1522, 1, -1, 512, 128) # t1525: "cuda:0 bf16[1, 32, 512, 128]" - # t1525 = prims.reshape(t1522, (1, 32, 512, 128)) # t1525: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1526 = ltorch.reshape(t1523, 1, -1, 512, 128) # t1526: "cuda:0 bf16[1, 32, 512, 128]" - # t1526 = prims.reshape(t1523, (1, 32, 512, 128)) # t1526: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1527 = ltorch.reshape(t1524, 1, -1, 512, 128) # t1527: "cuda:0 bf16[1, 32, 512, 128]" - # t1527 = prims.reshape(t1524, (1, 32, 512, 128)) # t1527: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1528 = ltorch.getitem(t1525, (..., slice(None, 128, None))) # t1528: "cuda:0 bf16[1, 32, 512, 128]" - # t1528 = prims.slice_prim(t1525, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1528: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1529 = ltorch.getitem(t1528, (..., slice(None, 64, None))) # t1529: "cuda:0 bf16[1, 32, 512, 64]" - # t1529 = prims.slice_prim(t1528, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1529: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1530 = ltorch.getitem(t1528, (..., slice(64, None, None))) # t1530: "cuda:0 bf16[1, 32, 512, 64]" - # t1530 = prims.slice_prim(t1528, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1530: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1533 = ltorch.neg(t1530) # t1533: "cuda:0 bf16[1, 32, 512, 64]" - # t1531 = prims.convert_element_type(t1530, dtypes.float32) # t1531: "cuda:0 f32[1, 32, 512, 64]" - # t1532 = prims.neg(t1531) # t1532: "cuda:0 f32[1, 32, 512, 64]" - # t1533 = prims.convert_element_type(t1532, dtypes.bfloat16) # t1533: "cuda:0 bf16[1, 32, 512, 64]" - t1534 = ltorch.cat((t1533, t1529), -1) # t1534: "cuda:0 bf16[1, 32, 512, 128]" - # t1534 = prims.cat((t1533, t1529), -1) # t1534: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1537 = ltorch.mul(t1528, cos) # t1537: "cuda:0 f32[1, 32, 512, 128]" - # t1535 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1535: "cuda:0 f32[1, 32, 512, 128]" - # t1536 = prims.convert_element_type(t1528, dtypes.float32) # t1536: "cuda:0 f32[1, 32, 512, 128]" - # t1537 = prims.mul(t1536, t1535) # t1537: "cuda:0 f32[1, 32, 512, 128]" - t1540 = ltorch.mul(t1534, sin) # t1540: "cuda:0 f32[1, 32, 512, 128]" - # t1538 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1538: "cuda:0 f32[1, 32, 512, 128]" - # t1539 = prims.convert_element_type(t1534, dtypes.float32) # t1539: "cuda:0 f32[1, 32, 512, 128]" - # t1540 = prims.mul(t1539, t1538) # t1540: "cuda:0 f32[1, 32, 512, 128]" - t1541 = ltorch.add(t1537, t1540, alpha=None) # t1541: "cuda:0 f32[1, 32, 512, 128]" - # t1541 = prims.add(t1537, t1540) # t1541: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1542 = ltorch.to(t1541, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1542: "cuda:0 bf16[1, 32, 512, 128]" - # t1542 = prims.convert_element_type(t1541, dtypes.bfloat16) # t1542: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1543 = ltorch.getitem(t1526, (..., slice(None, 128, None))) # t1543: "cuda:0 bf16[1, 32, 512, 128]" - # t1543 = prims.slice_prim(t1526, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1543: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1544 = ltorch.getitem(t1543, (..., slice(None, 64, None))) # t1544: "cuda:0 bf16[1, 32, 512, 64]" - # t1544 = prims.slice_prim(t1543, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1544: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1545 = ltorch.getitem(t1543, (..., slice(64, None, None))) # t1545: "cuda:0 bf16[1, 32, 512, 64]" - # t1545 = prims.slice_prim(t1543, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1545: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1548 = ltorch.neg(t1545) # t1548: "cuda:0 bf16[1, 32, 512, 64]" - # t1546 = prims.convert_element_type(t1545, dtypes.float32) # t1546: "cuda:0 f32[1, 32, 512, 64]" - # t1547 = prims.neg(t1546) # t1547: "cuda:0 f32[1, 32, 512, 64]" - # t1548 = prims.convert_element_type(t1547, dtypes.bfloat16) # t1548: "cuda:0 bf16[1, 32, 512, 64]" - t1549 = ltorch.cat((t1548, t1544), -1) # t1549: "cuda:0 bf16[1, 32, 512, 128]" - # t1549 = prims.cat((t1548, t1544), -1) # t1549: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1552 = ltorch.mul(t1543, cos) # t1552: "cuda:0 f32[1, 32, 512, 128]" - # t1550 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1550: "cuda:0 f32[1, 32, 512, 128]" - # t1551 = prims.convert_element_type(t1543, dtypes.float32) # t1551: "cuda:0 f32[1, 32, 512, 128]" - # t1552 = prims.mul(t1551, t1550) # t1552: "cuda:0 f32[1, 32, 512, 128]" - t1555 = ltorch.mul(t1549, sin) # t1555: "cuda:0 f32[1, 32, 512, 128]" - # t1553 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1553: "cuda:0 f32[1, 32, 512, 128]" - # t1554 = prims.convert_element_type(t1549, dtypes.float32) # t1554: "cuda:0 f32[1, 32, 512, 128]" - # t1555 = prims.mul(t1554, t1553) # t1555: "cuda:0 f32[1, 32, 512, 128]" - t1556 = ltorch.add(t1552, t1555, alpha=None) # t1556: "cuda:0 f32[1, 32, 512, 128]" - # t1556 = prims.add(t1552, t1555) # t1556: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1557 = ltorch.to(t1556, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1557: "cuda:0 bf16[1, 32, 512, 128]" - # t1557 = prims.convert_element_type(t1556, dtypes.bfloat16) # t1557: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1558 = ltorch.getitem(t1525, (..., slice(128, None, None))) # t1558: "cuda:0 bf16[1, 32, 512, 0]" - # t1558 = prims.slice_prim(t1525, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1558: "cuda:0 bf16[1, 32, 512, 0]" - t1559 = ltorch.cat((t1542, t1558), -1) # t1559: "cuda:0 bf16[1, 32, 512, 128]" - # t1559 = prims.cat((t1542, t1558), -1) # t1559: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1560 = ltorch.getitem(t1526, (..., slice(128, None, None))) # t1560: "cuda:0 bf16[1, 32, 512, 0]" - # t1560 = prims.slice_prim(t1526, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1560: "cuda:0 bf16[1, 32, 512, 0]" - t1561 = ltorch.cat((t1557, t1560), -1) # t1561: "cuda:0 bf16[1, 32, 512, 128]" - # t1561 = prims.cat((t1557, t1560), -1) # t1561: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1591 = ltorch.scaled_dot_product_attention(t1559, t1561, t1527, None, 0.0, True, scale=0.08838834764831843) # t1591: "cuda:0 bf16[1, 32, 512, 128]" - # t1564 = ltorch.mul(t1559, 0.29730177875068026) # t1564: "cuda:0 bf16[1, 32, 512, 128]" - # t1562 = prims.convert_element_type(t1559, dtypes.float32) # t1562: "cuda:0 f32[1, 32, 512, 128]" - # t1563 = prims.mul(t1562, 0.29730177875068026) # t1563: "cuda:0 f32[1, 32, 512, 128]" - # t1564 = prims.convert_element_type(t1563, dtypes.bfloat16) # t1564: "cuda:0 bf16[1, 32, 512, 128]" - # t1565 = ltorch.transpose(t1561, -2, -1) # t1565: "cuda:0 bf16[1, 32, 128, 512]" - # t1565 = prims.transpose(t1561, (0, 1, 3, 2)) # t1565: "cuda:0 bf16[1, 32, 128, 512]" - # t1568 = ltorch.mul(t1565, 0.29730177875068026) # t1568: "cuda:0 bf16[1, 32, 128, 512]" - # t1566 = prims.convert_element_type(t1565, dtypes.float32) # t1566: "cuda:0 f32[1, 32, 128, 512]" - # t1567 = prims.mul(t1566, 0.29730177875068026) # t1567: "cuda:0 f32[1, 32, 128, 512]" - # t1568 = prims.convert_element_type(t1567, dtypes.bfloat16) # t1568: "cuda:0 bf16[1, 32, 128, 512]" - # t1569 = ltorch.matmul(t1564, t1568) # t1569: "cuda:0 bf16[1, 32, 512, 512]" - # t1569 = prims.matmul(t1564, t1568) # t1569: "cuda:0 bf16[1, 32, 512, 512]" - # t1579 = ltorch.tril(t1569, 0, fill_value=-float('inf')) # t1579: "cuda:0 bf16[1, 32, 512, 512]" - # t1570 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1570: "cuda:0 i64[512]" - # t1570 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1570: "cuda:0 i64[512]" - # t1571 = ltorch.unsqueeze(t1570, -1) # t1571: "cuda:0 i64[512, 1]" - # t1571 = prims.broadcast_in_dim(t1570, [512, 1], [0]) # t1571: "cuda:0 i64[512, 1]" - # t1572 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1572: "cuda:0 i64[512]" - # t1572 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1572: "cuda:0 i64[512]" - # t1573 = ltorch.unsqueeze(t1572, -2) # t1573: "cuda:0 i64[1, 512]" - # t1573 = prims.broadcast_in_dim(t1572, [1, 512], [1]) # t1573: "cuda:0 i64[1, 512]" - # t1574 = ltorch.add(t1571, 0, alpha=None) # t1574: "cuda:0 i64[512, 1]" - # t1574 = prims.add(t1571, 0) # t1574: "cuda:0 i64[512, 1]" - # t1577 = ltorch.ge(t1574, t1573) # t1577: "cuda:0 b8[512, 512]" - # t1575 = prims.broadcast_in_dim(t1574, (512, 512), (0, 1)) # t1575: "cuda:0 i64[512, 512]" - # t1576 = prims.broadcast_in_dim(t1573, (512, 512), (0, 1)) # t1576: "cuda:0 i64[512, 512]" - # t1577 = prims.ge(t1575, t1576) # t1577: "cuda:0 b8[512, 512]" - # t1579 = ltorch.where(t1577, t1569, -float('inf')) # t1579: "cuda:0 bf16[1, 32, 512, 512]" - # t1578 = prims.broadcast_in_dim(t1577, (1, 32, 512, 512), (2, 3)) # t1578: "cuda:0 b8[1, 32, 512, 512]" - # t1579 = prims.where(t1578, t1569, -float('inf')) # t1579: "cuda:0 bf16[1, 32, 512, 512]" - # t1590 = ltorch._softmax(t1579, -1, dtype=None) # t1590: "cuda:0 bf16[1, 32, 512, 512]" - # t1580 = ltorch.to(t1579, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1580: "cuda:0 f32[1, 32, 512, 512]" - # t1580 = prims.convert_element_type(t1579, dtypes.float32) # t1580: "cuda:0 f32[1, 32, 512, 512]" - # t1582 = ltorch.amax(t1580, -1, True) # t1582: "cuda:0 f32[1, 32, 512, 1]" - # t1581 = prims.amax(t1580, (3,)) # t1581: "cuda:0 f32[1, 32, 512]" - # t1582 = prims.broadcast_in_dim(t1581, [1, 32, 512, 1], [0, 1, 2]) # t1582: "cuda:0 f32[1, 32, 512, 1]" - # t1584 = ltorch.sub(t1580, t1582, alpha=None) # t1584: "cuda:0 f32[1, 32, 512, 512]" - # t1583 = prims.broadcast_in_dim(t1582, (1, 32, 512, 512), (0, 1, 2, 3)) # t1583: "cuda:0 f32[1, 32, 512, 512]" - # t1584 = prims.sub(t1580, t1583) # t1584: "cuda:0 f32[1, 32, 512, 512]" - # t1585 = ltorch.exp(t1584) # t1585: "cuda:0 f32[1, 32, 512, 512]" - # t1585 = prims.exp(t1584) # t1585: "cuda:0 f32[1, 32, 512, 512]" - # t1587 = ltorch.sum(t1585, -1, True, dtype=None) # t1587: "cuda:0 f32[1, 32, 512, 1]" - # t1586 = prims.sum(t1585, (3,)) # t1586: "cuda:0 f32[1, 32, 512]" - # t1587 = prims.broadcast_in_dim(t1586, [1, 32, 512, 1], [0, 1, 2]) # t1587: "cuda:0 f32[1, 32, 512, 1]" - # t1589 = ltorch.true_divide(t1585, t1587) # t1589: "cuda:0 f32[1, 32, 512, 512]" - # t1588 = prims.broadcast_in_dim(t1587, (1, 32, 512, 512), (0, 1, 2, 3)) # t1588: "cuda:0 f32[1, 32, 512, 512]" - # t1589 = prims.div(t1585, t1588) # t1589: "cuda:0 f32[1, 32, 512, 512]" - # t1590 = ltorch.to(t1589, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1590: "cuda:0 bf16[1, 32, 512, 512]" - # t1590 = prims.convert_element_type(t1589, dtypes.bfloat16) # t1590: "cuda:0 bf16[1, 32, 512, 512]" - # t1591 = ltorch.matmul(t1590, t1527) # t1591: "cuda:0 bf16[1, 32, 512, 128]" - # t1591 = prims.matmul(t1590, t1527) # t1591: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1592 = ltorch.transpose(t1591, 1, 2) # t1592: "cuda:0 bf16[1, 512, 32, 128]" - # t1592 = prims.transpose(t1591, (0, 2, 1, 3)) # t1592: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1593 = ltorch.reshape(t1592, 1, 512, 4096) # t1593: "cuda:0 bf16[1, 512, 4096]" - # t1593 = prims.reshape(t1592, (1, 512, 4096)) # t1593: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1597 = ltorch.linear(t1593, t_transformer_h_9_attn_proj_weight, None) # t1597: "cuda:0 bf16[1, 512, 4096]" - # t1597 = prims.linear(t1593, t_transformer_h_9_attn_proj_weight, None) # t1597: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1601 = ltorch.add(t1597, t1491, alpha=None) # t1601: "cuda:0 bf16[1, 512, 4096]" - # t1598 = prims.convert_element_type(t1597, dtypes.float32) # t1598: "cuda:0 f32[1, 512, 4096]" - # t1599 = prims.convert_element_type(t1491, dtypes.float32) # t1599: "cuda:0 f32[1, 512, 4096]" - # t1600 = prims.add(t1598, t1599) # t1600: "cuda:0 f32[1, 512, 4096]" - # t1601 = prims.convert_element_type(t1600, dtypes.bfloat16) # t1601: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1602 = prims.convert_element_type(t1601, dtypes.float32) # t1602: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1603 = ltorch.mul(t1602, t1602) # t1603: "cuda:0 f32[1, 512, 4096]" - # t1603 = prims.mul(t1602, t1602) # t1603: "cuda:0 f32[1, 512, 4096]" - t1607 = ltorch.mean(t1603, -1, True, dtype=None) # t1607: "cuda:0 f32[1, 512, 1]" - # t1605 = prims.sum(t1603, (2,)) # t1605: "cuda:0 f32[1, 512]" - # t1606 = prims.broadcast_in_dim(t1605, [1, 512, 1], [0, 1]) # t1606: "cuda:0 f32[1, 512, 1]" - # t1607 = ltorch.true_divide(t1606, 4096) # t1607: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1607 = prims.div(t1606, 4096.0) # t1607: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1609 = ltorch.add(t1607, 1e-05, alpha=None) # t1609: "cuda:0 f32[1, 512, 1]" - # t1609 = prims.add(t1607, 1e-05) # t1609: "cuda:0 f32[1, 512, 1]" - t1610 = ltorch.rsqrt(t1609) # t1610: "cuda:0 f32[1, 512, 1]" - # t1610 = prims.rsqrt(t1609) # t1610: "cuda:0 f32[1, 512, 1]" - t1612 = ltorch.mul(t1602, t1610) # t1612: "cuda:0 f32[1, 512, 4096]" - # t1611 = prims.broadcast_in_dim(t1610, (1, 512, 4096), (0, 1, 2)) # t1611: "cuda:0 f32[1, 512, 4096]" - # t1612 = prims.mul(t1602, t1611) # t1612: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1613 = ltorch.to(t1612, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1613: "cuda:0 bf16[1, 512, 4096]" - # t1613 = prims.convert_element_type(t1612, dtypes.bfloat16) # t1613: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1623 = ltorch.mul(t1613, t_transformer_h_9_norm_2_weight) # t1623: "cuda:0 bf16[1, 512, 4096]" - # t1619 = prims.broadcast_in_dim(t_transformer_h_9_norm_2_weight, (1, 512, 4096), (2,)) # t1619: "cuda:0 bf16[1, 512, 4096]" - # t1620 = prims.convert_element_type(t1613, dtypes.float32) # t1620: "cuda:0 f32[1, 512, 4096]" - # t1621 = prims.convert_element_type(t1619, dtypes.float32) # t1621: "cuda:0 f32[1, 512, 4096]" - # t1622 = prims.mul(t1620, t1621) # t1622: "cuda:0 f32[1, 512, 4096]" - # t1623 = prims.convert_element_type(t1622, dtypes.bfloat16) # t1623: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1628 = ltorch.linear(t1623, t_transformer_h_9_mlp_fc_1_weight, None) # t1628: "cuda:0 bf16[1, 512, 11008]" - # t1628 = prims.linear(t1623, t_transformer_h_9_mlp_fc_1_weight, None) # t1628: "cuda:0 bf16[1, 512, 11008]" - t1632 = ltorch.linear(t1623, t_transformer_h_9_mlp_fc_2_weight, None) # t1632: "cuda:0 bf16[1, 512, 11008]" - # t1632 = prims.linear(t1623, t_transformer_h_9_mlp_fc_2_weight, None) # t1632: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1642 = ltorch.silu(t1628, False) # t1642: "cuda:0 bf16[1, 512, 11008]" - # t1633 = prims.convert_element_type(t1628, dtypes.float32) # t1633: "cuda:0 f32[1, 512, 11008]" - # t1634 = prims.neg(t1633) # t1634: "cuda:0 f32[1, 512, 11008]" - # t1635 = prims.exp(t1634) # t1635: "cuda:0 f32[1, 512, 11008]" - # t1636 = prims.add(1.0, t1635) # t1636: "cuda:0 f32[1, 512, 11008]" - # t1637 = prims.reciprocal(t1636) # t1637: "cuda:0 f32[1, 512, 11008]" - # t1638 = prims.convert_element_type(t1637, dtypes.bfloat16) # t1638: "cuda:0 bf16[1, 512, 11008]" - # t1639 = prims.convert_element_type(t1628, dtypes.float32) # t1639: "cuda:0 f32[1, 512, 11008]" - # t1640 = prims.convert_element_type(t1638, dtypes.float32) # t1640: "cuda:0 f32[1, 512, 11008]" - # t1641 = prims.mul(t1639, t1640) # t1641: "cuda:0 f32[1, 512, 11008]" - # t1642 = prims.convert_element_type(t1641, dtypes.bfloat16) # t1642: "cuda:0 bf16[1, 512, 11008]" - t1646 = ltorch.mul(t1642, t1632) # t1646: "cuda:0 bf16[1, 512, 11008]" - # t1643 = prims.convert_element_type(t1642, dtypes.float32) # t1643: "cuda:0 f32[1, 512, 11008]" - # t1644 = prims.convert_element_type(t1632, dtypes.float32) # t1644: "cuda:0 f32[1, 512, 11008]" - # t1645 = prims.mul(t1643, t1644) # t1645: "cuda:0 f32[1, 512, 11008]" - # t1646 = prims.convert_element_type(t1645, dtypes.bfloat16) # t1646: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1650 = ltorch.linear(t1646, t_transformer_h_9_mlp_proj_weight, None) # t1650: "cuda:0 bf16[1, 512, 4096]" - # t1650 = prims.linear(t1646, t_transformer_h_9_mlp_proj_weight, None) # t1650: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1654 = ltorch.add(t1650, t1601, alpha=None) # t1654: "cuda:0 bf16[1, 512, 4096]" - # t1651 = prims.convert_element_type(t1650, dtypes.float32) # t1651: "cuda:0 f32[1, 512, 4096]" - # t1652 = prims.convert_element_type(t1601, dtypes.float32) # t1652: "cuda:0 f32[1, 512, 4096]" - # t1653 = prims.add(t1651, t1652) # t1653: "cuda:0 f32[1, 512, 4096]" - # t1654 = prims.convert_element_type(t1653, dtypes.bfloat16) # t1654: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1656 = prims.convert_element_type(t1654, dtypes.float32) # t1656: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1657 = ltorch.mul(t1656, t1656) # t1657: "cuda:0 f32[1, 512, 4096]" - # t1657 = prims.mul(t1656, t1656) # t1657: "cuda:0 f32[1, 512, 4096]" - t1661 = ltorch.mean(t1657, -1, True, dtype=None) # t1661: "cuda:0 f32[1, 512, 1]" - # t1659 = prims.sum(t1657, (2,)) # t1659: "cuda:0 f32[1, 512]" - # t1660 = prims.broadcast_in_dim(t1659, [1, 512, 1], [0, 1]) # t1660: "cuda:0 f32[1, 512, 1]" - # t1661 = ltorch.true_divide(t1660, 4096) # t1661: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1661 = prims.div(t1660, 4096.0) # t1661: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1663 = ltorch.add(t1661, 1e-05, alpha=None) # t1663: "cuda:0 f32[1, 512, 1]" - # t1663 = prims.add(t1661, 1e-05) # t1663: "cuda:0 f32[1, 512, 1]" - t1664 = ltorch.rsqrt(t1663) # t1664: "cuda:0 f32[1, 512, 1]" - # t1664 = prims.rsqrt(t1663) # t1664: "cuda:0 f32[1, 512, 1]" - t1666 = ltorch.mul(t1656, t1664) # t1666: "cuda:0 f32[1, 512, 4096]" - # t1665 = prims.broadcast_in_dim(t1664, (1, 512, 4096), (0, 1, 2)) # t1665: "cuda:0 f32[1, 512, 4096]" - # t1666 = prims.mul(t1656, t1665) # t1666: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1667 = ltorch.to(t1666, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1667: "cuda:0 bf16[1, 512, 4096]" - # t1667 = prims.convert_element_type(t1666, dtypes.bfloat16) # t1667: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1677 = ltorch.mul(t1667, t_transformer_h_10_norm_1_weight) # t1677: "cuda:0 bf16[1, 512, 4096]" - # t1673 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_weight, (1, 512, 4096), (2,)) # t1673: "cuda:0 bf16[1, 512, 4096]" - # t1674 = prims.convert_element_type(t1667, dtypes.float32) # t1674: "cuda:0 f32[1, 512, 4096]" - # t1675 = prims.convert_element_type(t1673, dtypes.float32) # t1675: "cuda:0 f32[1, 512, 4096]" - # t1676 = prims.mul(t1674, t1675) # t1676: "cuda:0 f32[1, 512, 4096]" - # t1677 = prims.convert_element_type(t1676, dtypes.bfloat16) # t1677: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1682 = ltorch.linear(t1677, t_transformer_h_10_attn_attn_weight, None) # t1682: "cuda:0 bf16[1, 512, 12288]" - # t1682 = prims.linear(t1677, t_transformer_h_10_attn_attn_weight, None) # t1682: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1683 = ltorch.view(t1682, 1, 512, 32, 3, 128) # t1683: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1683 = ltorch.reshape(t1682, (1, 512, 32, 3, 128)) # t1683: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1683 = prims.reshape(t1682, (1, 512, 32, 3, 128)) # t1683: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1684 = ltorch.permute(t1683, 0, 2, 3, 1, 4) # t1684: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1684 = prims.transpose(t1683, (0, 2, 3, 1, 4)) # t1684: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1685, t1686, t1687) = ltorch.split(t1684, (1, 1, 1), 2) - # t1685 = prims.slice_prim(t1684, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1685: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1686 = prims.slice_prim(t1684, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1686: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1687 = prims.slice_prim(t1684, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1687: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1688 = ltorch.reshape(t1685, 1, -1, 512, 128) # t1688: "cuda:0 bf16[1, 32, 512, 128]" - # t1688 = prims.reshape(t1685, (1, 32, 512, 128)) # t1688: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1689 = ltorch.reshape(t1686, 1, -1, 512, 128) # t1689: "cuda:0 bf16[1, 32, 512, 128]" - # t1689 = prims.reshape(t1686, (1, 32, 512, 128)) # t1689: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1690 = ltorch.reshape(t1687, 1, -1, 512, 128) # t1690: "cuda:0 bf16[1, 32, 512, 128]" - # t1690 = prims.reshape(t1687, (1, 32, 512, 128)) # t1690: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1691 = ltorch.getitem(t1688, (..., slice(None, 128, None))) # t1691: "cuda:0 bf16[1, 32, 512, 128]" - # t1691 = prims.slice_prim(t1688, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1691: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1692 = ltorch.getitem(t1691, (..., slice(None, 64, None))) # t1692: "cuda:0 bf16[1, 32, 512, 64]" - # t1692 = prims.slice_prim(t1691, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1692: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1693 = ltorch.getitem(t1691, (..., slice(64, None, None))) # t1693: "cuda:0 bf16[1, 32, 512, 64]" - # t1693 = prims.slice_prim(t1691, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1693: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1696 = ltorch.neg(t1693) # t1696: "cuda:0 bf16[1, 32, 512, 64]" - # t1694 = prims.convert_element_type(t1693, dtypes.float32) # t1694: "cuda:0 f32[1, 32, 512, 64]" - # t1695 = prims.neg(t1694) # t1695: "cuda:0 f32[1, 32, 512, 64]" - # t1696 = prims.convert_element_type(t1695, dtypes.bfloat16) # t1696: "cuda:0 bf16[1, 32, 512, 64]" - t1697 = ltorch.cat((t1696, t1692), -1) # t1697: "cuda:0 bf16[1, 32, 512, 128]" - # t1697 = prims.cat((t1696, t1692), -1) # t1697: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1700 = ltorch.mul(t1691, cos) # t1700: "cuda:0 f32[1, 32, 512, 128]" - # t1698 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1698: "cuda:0 f32[1, 32, 512, 128]" - # t1699 = prims.convert_element_type(t1691, dtypes.float32) # t1699: "cuda:0 f32[1, 32, 512, 128]" - # t1700 = prims.mul(t1699, t1698) # t1700: "cuda:0 f32[1, 32, 512, 128]" - t1703 = ltorch.mul(t1697, sin) # t1703: "cuda:0 f32[1, 32, 512, 128]" - # t1701 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1701: "cuda:0 f32[1, 32, 512, 128]" - # t1702 = prims.convert_element_type(t1697, dtypes.float32) # t1702: "cuda:0 f32[1, 32, 512, 128]" - # t1703 = prims.mul(t1702, t1701) # t1703: "cuda:0 f32[1, 32, 512, 128]" - t1704 = ltorch.add(t1700, t1703, alpha=None) # t1704: "cuda:0 f32[1, 32, 512, 128]" - # t1704 = prims.add(t1700, t1703) # t1704: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1705 = ltorch.to(t1704, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1705: "cuda:0 bf16[1, 32, 512, 128]" - # t1705 = prims.convert_element_type(t1704, dtypes.bfloat16) # t1705: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1706 = ltorch.getitem(t1689, (..., slice(None, 128, None))) # t1706: "cuda:0 bf16[1, 32, 512, 128]" - # t1706 = prims.slice_prim(t1689, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1706: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1707 = ltorch.getitem(t1706, (..., slice(None, 64, None))) # t1707: "cuda:0 bf16[1, 32, 512, 64]" - # t1707 = prims.slice_prim(t1706, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1707: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1708 = ltorch.getitem(t1706, (..., slice(64, None, None))) # t1708: "cuda:0 bf16[1, 32, 512, 64]" - # t1708 = prims.slice_prim(t1706, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1708: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1711 = ltorch.neg(t1708) # t1711: "cuda:0 bf16[1, 32, 512, 64]" - # t1709 = prims.convert_element_type(t1708, dtypes.float32) # t1709: "cuda:0 f32[1, 32, 512, 64]" - # t1710 = prims.neg(t1709) # t1710: "cuda:0 f32[1, 32, 512, 64]" - # t1711 = prims.convert_element_type(t1710, dtypes.bfloat16) # t1711: "cuda:0 bf16[1, 32, 512, 64]" - t1712 = ltorch.cat((t1711, t1707), -1) # t1712: "cuda:0 bf16[1, 32, 512, 128]" - # t1712 = prims.cat((t1711, t1707), -1) # t1712: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1715 = ltorch.mul(t1706, cos) # t1715: "cuda:0 f32[1, 32, 512, 128]" - # t1713 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1713: "cuda:0 f32[1, 32, 512, 128]" - # t1714 = prims.convert_element_type(t1706, dtypes.float32) # t1714: "cuda:0 f32[1, 32, 512, 128]" - # t1715 = prims.mul(t1714, t1713) # t1715: "cuda:0 f32[1, 32, 512, 128]" - t1718 = ltorch.mul(t1712, sin) # t1718: "cuda:0 f32[1, 32, 512, 128]" - # t1716 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1716: "cuda:0 f32[1, 32, 512, 128]" - # t1717 = prims.convert_element_type(t1712, dtypes.float32) # t1717: "cuda:0 f32[1, 32, 512, 128]" - # t1718 = prims.mul(t1717, t1716) # t1718: "cuda:0 f32[1, 32, 512, 128]" - t1719 = ltorch.add(t1715, t1718, alpha=None) # t1719: "cuda:0 f32[1, 32, 512, 128]" - # t1719 = prims.add(t1715, t1718) # t1719: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1720 = ltorch.to(t1719, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1720: "cuda:0 bf16[1, 32, 512, 128]" - # t1720 = prims.convert_element_type(t1719, dtypes.bfloat16) # t1720: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1721 = ltorch.getitem(t1688, (..., slice(128, None, None))) # t1721: "cuda:0 bf16[1, 32, 512, 0]" - # t1721 = prims.slice_prim(t1688, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1721: "cuda:0 bf16[1, 32, 512, 0]" - t1722 = ltorch.cat((t1705, t1721), -1) # t1722: "cuda:0 bf16[1, 32, 512, 128]" - # t1722 = prims.cat((t1705, t1721), -1) # t1722: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1723 = ltorch.getitem(t1689, (..., slice(128, None, None))) # t1723: "cuda:0 bf16[1, 32, 512, 0]" - # t1723 = prims.slice_prim(t1689, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1723: "cuda:0 bf16[1, 32, 512, 0]" - t1724 = ltorch.cat((t1720, t1723), -1) # t1724: "cuda:0 bf16[1, 32, 512, 128]" - # t1724 = prims.cat((t1720, t1723), -1) # t1724: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1754 = ltorch.scaled_dot_product_attention(t1722, t1724, t1690, None, 0.0, True, scale=0.08838834764831843) # t1754: "cuda:0 bf16[1, 32, 512, 128]" - # t1727 = ltorch.mul(t1722, 0.29730177875068026) # t1727: "cuda:0 bf16[1, 32, 512, 128]" - # t1725 = prims.convert_element_type(t1722, dtypes.float32) # t1725: "cuda:0 f32[1, 32, 512, 128]" - # t1726 = prims.mul(t1725, 0.29730177875068026) # t1726: "cuda:0 f32[1, 32, 512, 128]" - # t1727 = prims.convert_element_type(t1726, dtypes.bfloat16) # t1727: "cuda:0 bf16[1, 32, 512, 128]" - # t1728 = ltorch.transpose(t1724, -2, -1) # t1728: "cuda:0 bf16[1, 32, 128, 512]" - # t1728 = prims.transpose(t1724, (0, 1, 3, 2)) # t1728: "cuda:0 bf16[1, 32, 128, 512]" - # t1731 = ltorch.mul(t1728, 0.29730177875068026) # t1731: "cuda:0 bf16[1, 32, 128, 512]" - # t1729 = prims.convert_element_type(t1728, dtypes.float32) # t1729: "cuda:0 f32[1, 32, 128, 512]" - # t1730 = prims.mul(t1729, 0.29730177875068026) # t1730: "cuda:0 f32[1, 32, 128, 512]" - # t1731 = prims.convert_element_type(t1730, dtypes.bfloat16) # t1731: "cuda:0 bf16[1, 32, 128, 512]" - # t1732 = ltorch.matmul(t1727, t1731) # t1732: "cuda:0 bf16[1, 32, 512, 512]" - # t1732 = prims.matmul(t1727, t1731) # t1732: "cuda:0 bf16[1, 32, 512, 512]" - # t1742 = ltorch.tril(t1732, 0, fill_value=-float('inf')) # t1742: "cuda:0 bf16[1, 32, 512, 512]" - # t1733 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1733: "cuda:0 i64[512]" - # t1733 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1733: "cuda:0 i64[512]" - # t1734 = ltorch.unsqueeze(t1733, -1) # t1734: "cuda:0 i64[512, 1]" - # t1734 = prims.broadcast_in_dim(t1733, [512, 1], [0]) # t1734: "cuda:0 i64[512, 1]" - # t1735 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1735: "cuda:0 i64[512]" - # t1735 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1735: "cuda:0 i64[512]" - # t1736 = ltorch.unsqueeze(t1735, -2) # t1736: "cuda:0 i64[1, 512]" - # t1736 = prims.broadcast_in_dim(t1735, [1, 512], [1]) # t1736: "cuda:0 i64[1, 512]" - # t1737 = ltorch.add(t1734, 0, alpha=None) # t1737: "cuda:0 i64[512, 1]" - # t1737 = prims.add(t1734, 0) # t1737: "cuda:0 i64[512, 1]" - # t1740 = ltorch.ge(t1737, t1736) # t1740: "cuda:0 b8[512, 512]" - # t1738 = prims.broadcast_in_dim(t1737, (512, 512), (0, 1)) # t1738: "cuda:0 i64[512, 512]" - # t1739 = prims.broadcast_in_dim(t1736, (512, 512), (0, 1)) # t1739: "cuda:0 i64[512, 512]" - # t1740 = prims.ge(t1738, t1739) # t1740: "cuda:0 b8[512, 512]" - # t1742 = ltorch.where(t1740, t1732, -float('inf')) # t1742: "cuda:0 bf16[1, 32, 512, 512]" - # t1741 = prims.broadcast_in_dim(t1740, (1, 32, 512, 512), (2, 3)) # t1741: "cuda:0 b8[1, 32, 512, 512]" - # t1742 = prims.where(t1741, t1732, -float('inf')) # t1742: "cuda:0 bf16[1, 32, 512, 512]" - # t1753 = ltorch._softmax(t1742, -1, dtype=None) # t1753: "cuda:0 bf16[1, 32, 512, 512]" - # t1743 = ltorch.to(t1742, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1743: "cuda:0 f32[1, 32, 512, 512]" - # t1743 = prims.convert_element_type(t1742, dtypes.float32) # t1743: "cuda:0 f32[1, 32, 512, 512]" - # t1745 = ltorch.amax(t1743, -1, True) # t1745: "cuda:0 f32[1, 32, 512, 1]" - # t1744 = prims.amax(t1743, (3,)) # t1744: "cuda:0 f32[1, 32, 512]" - # t1745 = prims.broadcast_in_dim(t1744, [1, 32, 512, 1], [0, 1, 2]) # t1745: "cuda:0 f32[1, 32, 512, 1]" - # t1747 = ltorch.sub(t1743, t1745, alpha=None) # t1747: "cuda:0 f32[1, 32, 512, 512]" - # t1746 = prims.broadcast_in_dim(t1745, (1, 32, 512, 512), (0, 1, 2, 3)) # t1746: "cuda:0 f32[1, 32, 512, 512]" - # t1747 = prims.sub(t1743, t1746) # t1747: "cuda:0 f32[1, 32, 512, 512]" - # t1748 = ltorch.exp(t1747) # t1748: "cuda:0 f32[1, 32, 512, 512]" - # t1748 = prims.exp(t1747) # t1748: "cuda:0 f32[1, 32, 512, 512]" - # t1750 = ltorch.sum(t1748, -1, True, dtype=None) # t1750: "cuda:0 f32[1, 32, 512, 1]" - # t1749 = prims.sum(t1748, (3,)) # t1749: "cuda:0 f32[1, 32, 512]" - # t1750 = prims.broadcast_in_dim(t1749, [1, 32, 512, 1], [0, 1, 2]) # t1750: "cuda:0 f32[1, 32, 512, 1]" - # t1752 = ltorch.true_divide(t1748, t1750) # t1752: "cuda:0 f32[1, 32, 512, 512]" - # t1751 = prims.broadcast_in_dim(t1750, (1, 32, 512, 512), (0, 1, 2, 3)) # t1751: "cuda:0 f32[1, 32, 512, 512]" - # t1752 = prims.div(t1748, t1751) # t1752: "cuda:0 f32[1, 32, 512, 512]" - # t1753 = ltorch.to(t1752, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1753: "cuda:0 bf16[1, 32, 512, 512]" - # t1753 = prims.convert_element_type(t1752, dtypes.bfloat16) # t1753: "cuda:0 bf16[1, 32, 512, 512]" - # t1754 = ltorch.matmul(t1753, t1690) # t1754: "cuda:0 bf16[1, 32, 512, 128]" - # t1754 = prims.matmul(t1753, t1690) # t1754: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1755 = ltorch.transpose(t1754, 1, 2) # t1755: "cuda:0 bf16[1, 512, 32, 128]" - # t1755 = prims.transpose(t1754, (0, 2, 1, 3)) # t1755: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1756 = ltorch.reshape(t1755, 1, 512, 4096) # t1756: "cuda:0 bf16[1, 512, 4096]" - # t1756 = prims.reshape(t1755, (1, 512, 4096)) # t1756: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1760 = ltorch.linear(t1756, t_transformer_h_10_attn_proj_weight, None) # t1760: "cuda:0 bf16[1, 512, 4096]" - # t1760 = prims.linear(t1756, t_transformer_h_10_attn_proj_weight, None) # t1760: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1764 = ltorch.add(t1760, t1654, alpha=None) # t1764: "cuda:0 bf16[1, 512, 4096]" - # t1761 = prims.convert_element_type(t1760, dtypes.float32) # t1761: "cuda:0 f32[1, 512, 4096]" - # t1762 = prims.convert_element_type(t1654, dtypes.float32) # t1762: "cuda:0 f32[1, 512, 4096]" - # t1763 = prims.add(t1761, t1762) # t1763: "cuda:0 f32[1, 512, 4096]" - # t1764 = prims.convert_element_type(t1763, dtypes.bfloat16) # t1764: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1765 = prims.convert_element_type(t1764, dtypes.float32) # t1765: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1766 = ltorch.mul(t1765, t1765) # t1766: "cuda:0 f32[1, 512, 4096]" - # t1766 = prims.mul(t1765, t1765) # t1766: "cuda:0 f32[1, 512, 4096]" - t1770 = ltorch.mean(t1766, -1, True, dtype=None) # t1770: "cuda:0 f32[1, 512, 1]" - # t1768 = prims.sum(t1766, (2,)) # t1768: "cuda:0 f32[1, 512]" - # t1769 = prims.broadcast_in_dim(t1768, [1, 512, 1], [0, 1]) # t1769: "cuda:0 f32[1, 512, 1]" - # t1770 = ltorch.true_divide(t1769, 4096) # t1770: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1770 = prims.div(t1769, 4096.0) # t1770: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1772 = ltorch.add(t1770, 1e-05, alpha=None) # t1772: "cuda:0 f32[1, 512, 1]" - # t1772 = prims.add(t1770, 1e-05) # t1772: "cuda:0 f32[1, 512, 1]" - t1773 = ltorch.rsqrt(t1772) # t1773: "cuda:0 f32[1, 512, 1]" - # t1773 = prims.rsqrt(t1772) # t1773: "cuda:0 f32[1, 512, 1]" - t1775 = ltorch.mul(t1765, t1773) # t1775: "cuda:0 f32[1, 512, 4096]" - # t1774 = prims.broadcast_in_dim(t1773, (1, 512, 4096), (0, 1, 2)) # t1774: "cuda:0 f32[1, 512, 4096]" - # t1775 = prims.mul(t1765, t1774) # t1775: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1776 = ltorch.to(t1775, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1776: "cuda:0 bf16[1, 512, 4096]" - # t1776 = prims.convert_element_type(t1775, dtypes.bfloat16) # t1776: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1786 = ltorch.mul(t1776, t_transformer_h_10_norm_2_weight) # t1786: "cuda:0 bf16[1, 512, 4096]" - # t1782 = prims.broadcast_in_dim(t_transformer_h_10_norm_2_weight, (1, 512, 4096), (2,)) # t1782: "cuda:0 bf16[1, 512, 4096]" - # t1783 = prims.convert_element_type(t1776, dtypes.float32) # t1783: "cuda:0 f32[1, 512, 4096]" - # t1784 = prims.convert_element_type(t1782, dtypes.float32) # t1784: "cuda:0 f32[1, 512, 4096]" - # t1785 = prims.mul(t1783, t1784) # t1785: "cuda:0 f32[1, 512, 4096]" - # t1786 = prims.convert_element_type(t1785, dtypes.bfloat16) # t1786: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1791 = ltorch.linear(t1786, t_transformer_h_10_mlp_fc_1_weight, None) # t1791: "cuda:0 bf16[1, 512, 11008]" - # t1791 = prims.linear(t1786, t_transformer_h_10_mlp_fc_1_weight, None) # t1791: "cuda:0 bf16[1, 512, 11008]" - t1795 = ltorch.linear(t1786, t_transformer_h_10_mlp_fc_2_weight, None) # t1795: "cuda:0 bf16[1, 512, 11008]" - # t1795 = prims.linear(t1786, t_transformer_h_10_mlp_fc_2_weight, None) # t1795: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1805 = ltorch.silu(t1791, False) # t1805: "cuda:0 bf16[1, 512, 11008]" - # t1796 = prims.convert_element_type(t1791, dtypes.float32) # t1796: "cuda:0 f32[1, 512, 11008]" - # t1797 = prims.neg(t1796) # t1797: "cuda:0 f32[1, 512, 11008]" - # t1798 = prims.exp(t1797) # t1798: "cuda:0 f32[1, 512, 11008]" - # t1799 = prims.add(1.0, t1798) # t1799: "cuda:0 f32[1, 512, 11008]" - # t1800 = prims.reciprocal(t1799) # t1800: "cuda:0 f32[1, 512, 11008]" - # t1801 = prims.convert_element_type(t1800, dtypes.bfloat16) # t1801: "cuda:0 bf16[1, 512, 11008]" - # t1802 = prims.convert_element_type(t1791, dtypes.float32) # t1802: "cuda:0 f32[1, 512, 11008]" - # t1803 = prims.convert_element_type(t1801, dtypes.float32) # t1803: "cuda:0 f32[1, 512, 11008]" - # t1804 = prims.mul(t1802, t1803) # t1804: "cuda:0 f32[1, 512, 11008]" - # t1805 = prims.convert_element_type(t1804, dtypes.bfloat16) # t1805: "cuda:0 bf16[1, 512, 11008]" - t1809 = ltorch.mul(t1805, t1795) # t1809: "cuda:0 bf16[1, 512, 11008]" - # t1806 = prims.convert_element_type(t1805, dtypes.float32) # t1806: "cuda:0 f32[1, 512, 11008]" - # t1807 = prims.convert_element_type(t1795, dtypes.float32) # t1807: "cuda:0 f32[1, 512, 11008]" - # t1808 = prims.mul(t1806, t1807) # t1808: "cuda:0 f32[1, 512, 11008]" - # t1809 = prims.convert_element_type(t1808, dtypes.bfloat16) # t1809: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1813 = ltorch.linear(t1809, t_transformer_h_10_mlp_proj_weight, None) # t1813: "cuda:0 bf16[1, 512, 4096]" - # t1813 = prims.linear(t1809, t_transformer_h_10_mlp_proj_weight, None) # t1813: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1817 = ltorch.add(t1813, t1764, alpha=None) # t1817: "cuda:0 bf16[1, 512, 4096]" - # t1814 = prims.convert_element_type(t1813, dtypes.float32) # t1814: "cuda:0 f32[1, 512, 4096]" - # t1815 = prims.convert_element_type(t1764, dtypes.float32) # t1815: "cuda:0 f32[1, 512, 4096]" - # t1816 = prims.add(t1814, t1815) # t1816: "cuda:0 f32[1, 512, 4096]" - # t1817 = prims.convert_element_type(t1816, dtypes.bfloat16) # t1817: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1819 = prims.convert_element_type(t1817, dtypes.float32) # t1819: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1820 = ltorch.mul(t1819, t1819) # t1820: "cuda:0 f32[1, 512, 4096]" - # t1820 = prims.mul(t1819, t1819) # t1820: "cuda:0 f32[1, 512, 4096]" - t1824 = ltorch.mean(t1820, -1, True, dtype=None) # t1824: "cuda:0 f32[1, 512, 1]" - # t1822 = prims.sum(t1820, (2,)) # t1822: "cuda:0 f32[1, 512]" - # t1823 = prims.broadcast_in_dim(t1822, [1, 512, 1], [0, 1]) # t1823: "cuda:0 f32[1, 512, 1]" - # t1824 = ltorch.true_divide(t1823, 4096) # t1824: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1824 = prims.div(t1823, 4096.0) # t1824: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1826 = ltorch.add(t1824, 1e-05, alpha=None) # t1826: "cuda:0 f32[1, 512, 1]" - # t1826 = prims.add(t1824, 1e-05) # t1826: "cuda:0 f32[1, 512, 1]" - t1827 = ltorch.rsqrt(t1826) # t1827: "cuda:0 f32[1, 512, 1]" - # t1827 = prims.rsqrt(t1826) # t1827: "cuda:0 f32[1, 512, 1]" - t1829 = ltorch.mul(t1819, t1827) # t1829: "cuda:0 f32[1, 512, 4096]" - # t1828 = prims.broadcast_in_dim(t1827, (1, 512, 4096), (0, 1, 2)) # t1828: "cuda:0 f32[1, 512, 4096]" - # t1829 = prims.mul(t1819, t1828) # t1829: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1830 = ltorch.to(t1829, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1830: "cuda:0 bf16[1, 512, 4096]" - # t1830 = prims.convert_element_type(t1829, dtypes.bfloat16) # t1830: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1840 = ltorch.mul(t1830, t_transformer_h_11_norm_1_weight) # t1840: "cuda:0 bf16[1, 512, 4096]" - # t1836 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_weight, (1, 512, 4096), (2,)) # t1836: "cuda:0 bf16[1, 512, 4096]" - # t1837 = prims.convert_element_type(t1830, dtypes.float32) # t1837: "cuda:0 f32[1, 512, 4096]" - # t1838 = prims.convert_element_type(t1836, dtypes.float32) # t1838: "cuda:0 f32[1, 512, 4096]" - # t1839 = prims.mul(t1837, t1838) # t1839: "cuda:0 f32[1, 512, 4096]" - # t1840 = prims.convert_element_type(t1839, dtypes.bfloat16) # t1840: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1845 = ltorch.linear(t1840, t_transformer_h_11_attn_attn_weight, None) # t1845: "cuda:0 bf16[1, 512, 12288]" - # t1845 = prims.linear(t1840, t_transformer_h_11_attn_attn_weight, None) # t1845: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t1846 = ltorch.view(t1845, 1, 512, 32, 3, 128) # t1846: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1846 = ltorch.reshape(t1845, (1, 512, 32, 3, 128)) # t1846: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t1846 = prims.reshape(t1845, (1, 512, 32, 3, 128)) # t1846: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t1847 = ltorch.permute(t1846, 0, 2, 3, 1, 4) # t1847: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t1847 = prims.transpose(t1846, (0, 2, 3, 1, 4)) # t1847: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t1848, t1849, t1850) = ltorch.split(t1847, (1, 1, 1), 2) - # t1848 = prims.slice_prim(t1847, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1848: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1849 = prims.slice_prim(t1847, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1849: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1850 = prims.slice_prim(t1847, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1850: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t1851 = ltorch.reshape(t1848, 1, -1, 512, 128) # t1851: "cuda:0 bf16[1, 32, 512, 128]" - # t1851 = prims.reshape(t1848, (1, 32, 512, 128)) # t1851: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t1852 = ltorch.reshape(t1849, 1, -1, 512, 128) # t1852: "cuda:0 bf16[1, 32, 512, 128]" - # t1852 = prims.reshape(t1849, (1, 32, 512, 128)) # t1852: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t1853 = ltorch.reshape(t1850, 1, -1, 512, 128) # t1853: "cuda:0 bf16[1, 32, 512, 128]" - # t1853 = prims.reshape(t1850, (1, 32, 512, 128)) # t1853: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t1854 = ltorch.getitem(t1851, (..., slice(None, 128, None))) # t1854: "cuda:0 bf16[1, 32, 512, 128]" - # t1854 = prims.slice_prim(t1851, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1854: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1855 = ltorch.getitem(t1854, (..., slice(None, 64, None))) # t1855: "cuda:0 bf16[1, 32, 512, 64]" - # t1855 = prims.slice_prim(t1854, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1855: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1856 = ltorch.getitem(t1854, (..., slice(64, None, None))) # t1856: "cuda:0 bf16[1, 32, 512, 64]" - # t1856 = prims.slice_prim(t1854, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1856: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1859 = ltorch.neg(t1856) # t1859: "cuda:0 bf16[1, 32, 512, 64]" - # t1857 = prims.convert_element_type(t1856, dtypes.float32) # t1857: "cuda:0 f32[1, 32, 512, 64]" - # t1858 = prims.neg(t1857) # t1858: "cuda:0 f32[1, 32, 512, 64]" - # t1859 = prims.convert_element_type(t1858, dtypes.bfloat16) # t1859: "cuda:0 bf16[1, 32, 512, 64]" - t1860 = ltorch.cat((t1859, t1855), -1) # t1860: "cuda:0 bf16[1, 32, 512, 128]" - # t1860 = prims.cat((t1859, t1855), -1) # t1860: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1863 = ltorch.mul(t1854, cos) # t1863: "cuda:0 f32[1, 32, 512, 128]" - # t1861 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1861: "cuda:0 f32[1, 32, 512, 128]" - # t1862 = prims.convert_element_type(t1854, dtypes.float32) # t1862: "cuda:0 f32[1, 32, 512, 128]" - # t1863 = prims.mul(t1862, t1861) # t1863: "cuda:0 f32[1, 32, 512, 128]" - t1866 = ltorch.mul(t1860, sin) # t1866: "cuda:0 f32[1, 32, 512, 128]" - # t1864 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1864: "cuda:0 f32[1, 32, 512, 128]" - # t1865 = prims.convert_element_type(t1860, dtypes.float32) # t1865: "cuda:0 f32[1, 32, 512, 128]" - # t1866 = prims.mul(t1865, t1864) # t1866: "cuda:0 f32[1, 32, 512, 128]" - t1867 = ltorch.add(t1863, t1866, alpha=None) # t1867: "cuda:0 f32[1, 32, 512, 128]" - # t1867 = prims.add(t1863, t1866) # t1867: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1868 = ltorch.to(t1867, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1868: "cuda:0 bf16[1, 32, 512, 128]" - # t1868 = prims.convert_element_type(t1867, dtypes.bfloat16) # t1868: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t1869 = ltorch.getitem(t1852, (..., slice(None, 128, None))) # t1869: "cuda:0 bf16[1, 32, 512, 128]" - # t1869 = prims.slice_prim(t1852, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1869: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t1870 = ltorch.getitem(t1869, (..., slice(None, 64, None))) # t1870: "cuda:0 bf16[1, 32, 512, 64]" - # t1870 = prims.slice_prim(t1869, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1870: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t1871 = ltorch.getitem(t1869, (..., slice(64, None, None))) # t1871: "cuda:0 bf16[1, 32, 512, 64]" - # t1871 = prims.slice_prim(t1869, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1871: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t1874 = ltorch.neg(t1871) # t1874: "cuda:0 bf16[1, 32, 512, 64]" - # t1872 = prims.convert_element_type(t1871, dtypes.float32) # t1872: "cuda:0 f32[1, 32, 512, 64]" - # t1873 = prims.neg(t1872) # t1873: "cuda:0 f32[1, 32, 512, 64]" - # t1874 = prims.convert_element_type(t1873, dtypes.bfloat16) # t1874: "cuda:0 bf16[1, 32, 512, 64]" - t1875 = ltorch.cat((t1874, t1870), -1) # t1875: "cuda:0 bf16[1, 32, 512, 128]" - # t1875 = prims.cat((t1874, t1870), -1) # t1875: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t1878 = ltorch.mul(t1869, cos) # t1878: "cuda:0 f32[1, 32, 512, 128]" - # t1876 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t1876: "cuda:0 f32[1, 32, 512, 128]" - # t1877 = prims.convert_element_type(t1869, dtypes.float32) # t1877: "cuda:0 f32[1, 32, 512, 128]" - # t1878 = prims.mul(t1877, t1876) # t1878: "cuda:0 f32[1, 32, 512, 128]" - t1881 = ltorch.mul(t1875, sin) # t1881: "cuda:0 f32[1, 32, 512, 128]" - # t1879 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t1879: "cuda:0 f32[1, 32, 512, 128]" - # t1880 = prims.convert_element_type(t1875, dtypes.float32) # t1880: "cuda:0 f32[1, 32, 512, 128]" - # t1881 = prims.mul(t1880, t1879) # t1881: "cuda:0 f32[1, 32, 512, 128]" - t1882 = ltorch.add(t1878, t1881, alpha=None) # t1882: "cuda:0 f32[1, 32, 512, 128]" - # t1882 = prims.add(t1878, t1881) # t1882: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t1883 = ltorch.to(t1882, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1883: "cuda:0 bf16[1, 32, 512, 128]" - # t1883 = prims.convert_element_type(t1882, dtypes.bfloat16) # t1883: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t1884 = ltorch.getitem(t1851, (..., slice(128, None, None))) # t1884: "cuda:0 bf16[1, 32, 512, 0]" - # t1884 = prims.slice_prim(t1851, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1884: "cuda:0 bf16[1, 32, 512, 0]" - t1885 = ltorch.cat((t1868, t1884), -1) # t1885: "cuda:0 bf16[1, 32, 512, 128]" - # t1885 = prims.cat((t1868, t1884), -1) # t1885: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t1886 = ltorch.getitem(t1852, (..., slice(128, None, None))) # t1886: "cuda:0 bf16[1, 32, 512, 0]" - # t1886 = prims.slice_prim(t1852, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1886: "cuda:0 bf16[1, 32, 512, 0]" - t1887 = ltorch.cat((t1883, t1886), -1) # t1887: "cuda:0 bf16[1, 32, 512, 128]" - # t1887 = prims.cat((t1883, t1886), -1) # t1887: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t1917 = ltorch.scaled_dot_product_attention(t1885, t1887, t1853, None, 0.0, True, scale=0.08838834764831843) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - # t1890 = ltorch.mul(t1885, 0.29730177875068026) # t1890: "cuda:0 bf16[1, 32, 512, 128]" - # t1888 = prims.convert_element_type(t1885, dtypes.float32) # t1888: "cuda:0 f32[1, 32, 512, 128]" - # t1889 = prims.mul(t1888, 0.29730177875068026) # t1889: "cuda:0 f32[1, 32, 512, 128]" - # t1890 = prims.convert_element_type(t1889, dtypes.bfloat16) # t1890: "cuda:0 bf16[1, 32, 512, 128]" - # t1891 = ltorch.transpose(t1887, -2, -1) # t1891: "cuda:0 bf16[1, 32, 128, 512]" - # t1891 = prims.transpose(t1887, (0, 1, 3, 2)) # t1891: "cuda:0 bf16[1, 32, 128, 512]" - # t1894 = ltorch.mul(t1891, 0.29730177875068026) # t1894: "cuda:0 bf16[1, 32, 128, 512]" - # t1892 = prims.convert_element_type(t1891, dtypes.float32) # t1892: "cuda:0 f32[1, 32, 128, 512]" - # t1893 = prims.mul(t1892, 0.29730177875068026) # t1893: "cuda:0 f32[1, 32, 128, 512]" - # t1894 = prims.convert_element_type(t1893, dtypes.bfloat16) # t1894: "cuda:0 bf16[1, 32, 128, 512]" - # t1895 = ltorch.matmul(t1890, t1894) # t1895: "cuda:0 bf16[1, 32, 512, 512]" - # t1895 = prims.matmul(t1890, t1894) # t1895: "cuda:0 bf16[1, 32, 512, 512]" - # t1905 = ltorch.tril(t1895, 0, fill_value=-float('inf')) # t1905: "cuda:0 bf16[1, 32, 512, 512]" - # t1896 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1896: "cuda:0 i64[512]" - # t1896 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1896: "cuda:0 i64[512]" - # t1897 = ltorch.unsqueeze(t1896, -1) # t1897: "cuda:0 i64[512, 1]" - # t1897 = prims.broadcast_in_dim(t1896, [512, 1], [0]) # t1897: "cuda:0 i64[512, 1]" - # t1898 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t1898: "cuda:0 i64[512]" - # t1898 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t1898: "cuda:0 i64[512]" - # t1899 = ltorch.unsqueeze(t1898, -2) # t1899: "cuda:0 i64[1, 512]" - # t1899 = prims.broadcast_in_dim(t1898, [1, 512], [1]) # t1899: "cuda:0 i64[1, 512]" - # t1900 = ltorch.add(t1897, 0, alpha=None) # t1900: "cuda:0 i64[512, 1]" - # t1900 = prims.add(t1897, 0) # t1900: "cuda:0 i64[512, 1]" - # t1903 = ltorch.ge(t1900, t1899) # t1903: "cuda:0 b8[512, 512]" - # t1901 = prims.broadcast_in_dim(t1900, (512, 512), (0, 1)) # t1901: "cuda:0 i64[512, 512]" - # t1902 = prims.broadcast_in_dim(t1899, (512, 512), (0, 1)) # t1902: "cuda:0 i64[512, 512]" - # t1903 = prims.ge(t1901, t1902) # t1903: "cuda:0 b8[512, 512]" - # t1905 = ltorch.where(t1903, t1895, -float('inf')) # t1905: "cuda:0 bf16[1, 32, 512, 512]" - # t1904 = prims.broadcast_in_dim(t1903, (1, 32, 512, 512), (2, 3)) # t1904: "cuda:0 b8[1, 32, 512, 512]" - # t1905 = prims.where(t1904, t1895, -float('inf')) # t1905: "cuda:0 bf16[1, 32, 512, 512]" - # t1916 = ltorch._softmax(t1905, -1, dtype=None) # t1916: "cuda:0 bf16[1, 32, 512, 512]" - # t1906 = ltorch.to(t1905, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t1906: "cuda:0 f32[1, 32, 512, 512]" - # t1906 = prims.convert_element_type(t1905, dtypes.float32) # t1906: "cuda:0 f32[1, 32, 512, 512]" - # t1908 = ltorch.amax(t1906, -1, True) # t1908: "cuda:0 f32[1, 32, 512, 1]" - # t1907 = prims.amax(t1906, (3,)) # t1907: "cuda:0 f32[1, 32, 512]" - # t1908 = prims.broadcast_in_dim(t1907, [1, 32, 512, 1], [0, 1, 2]) # t1908: "cuda:0 f32[1, 32, 512, 1]" - # t1910 = ltorch.sub(t1906, t1908, alpha=None) # t1910: "cuda:0 f32[1, 32, 512, 512]" - # t1909 = prims.broadcast_in_dim(t1908, (1, 32, 512, 512), (0, 1, 2, 3)) # t1909: "cuda:0 f32[1, 32, 512, 512]" - # t1910 = prims.sub(t1906, t1909) # t1910: "cuda:0 f32[1, 32, 512, 512]" - # t1911 = ltorch.exp(t1910) # t1911: "cuda:0 f32[1, 32, 512, 512]" - # t1911 = prims.exp(t1910) # t1911: "cuda:0 f32[1, 32, 512, 512]" - # t1913 = ltorch.sum(t1911, -1, True, dtype=None) # t1913: "cuda:0 f32[1, 32, 512, 1]" - # t1912 = prims.sum(t1911, (3,)) # t1912: "cuda:0 f32[1, 32, 512]" - # t1913 = prims.broadcast_in_dim(t1912, [1, 32, 512, 1], [0, 1, 2]) # t1913: "cuda:0 f32[1, 32, 512, 1]" - # t1915 = ltorch.true_divide(t1911, t1913) # t1915: "cuda:0 f32[1, 32, 512, 512]" - # t1914 = prims.broadcast_in_dim(t1913, (1, 32, 512, 512), (0, 1, 2, 3)) # t1914: "cuda:0 f32[1, 32, 512, 512]" - # t1915 = prims.div(t1911, t1914) # t1915: "cuda:0 f32[1, 32, 512, 512]" - # t1916 = ltorch.to(t1915, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t1916: "cuda:0 bf16[1, 32, 512, 512]" - # t1916 = prims.convert_element_type(t1915, dtypes.bfloat16) # t1916: "cuda:0 bf16[1, 32, 512, 512]" - # t1917 = ltorch.matmul(t1916, t1853) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - # t1917 = prims.matmul(t1916, t1853) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t1918 = ltorch.transpose(t1917, 1, 2) # t1918: "cuda:0 bf16[1, 512, 32, 128]" - # t1918 = prims.transpose(t1917, (0, 2, 1, 3)) # t1918: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t1919 = ltorch.reshape(t1918, 1, 512, 4096) # t1919: "cuda:0 bf16[1, 512, 4096]" - # t1919 = prims.reshape(t1918, (1, 512, 4096)) # t1919: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1923 = ltorch.linear(t1919, t_transformer_h_11_attn_proj_weight, None) # t1923: "cuda:0 bf16[1, 512, 4096]" - # t1923 = prims.linear(t1919, t_transformer_h_11_attn_proj_weight, None) # t1923: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t1927 = ltorch.add(t1923, t1817, alpha=None) # t1927: "cuda:0 bf16[1, 512, 4096]" - # t1924 = prims.convert_element_type(t1923, dtypes.float32) # t1924: "cuda:0 f32[1, 512, 4096]" - # t1925 = prims.convert_element_type(t1817, dtypes.float32) # t1925: "cuda:0 f32[1, 512, 4096]" - # t1926 = prims.add(t1924, t1925) # t1926: "cuda:0 f32[1, 512, 4096]" - # t1927 = prims.convert_element_type(t1926, dtypes.bfloat16) # t1927: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1928 = prims.convert_element_type(t1927, dtypes.float32) # t1928: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1929 = ltorch.mul(t1928, t1928) # t1929: "cuda:0 f32[1, 512, 4096]" - # t1929 = prims.mul(t1928, t1928) # t1929: "cuda:0 f32[1, 512, 4096]" - t1933 = ltorch.mean(t1929, -1, True, dtype=None) # t1933: "cuda:0 f32[1, 512, 1]" - # t1931 = prims.sum(t1929, (2,)) # t1931: "cuda:0 f32[1, 512]" - # t1932 = prims.broadcast_in_dim(t1931, [1, 512, 1], [0, 1]) # t1932: "cuda:0 f32[1, 512, 1]" - # t1933 = ltorch.true_divide(t1932, 4096) # t1933: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1933 = prims.div(t1932, 4096.0) # t1933: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1935 = ltorch.add(t1933, 1e-05, alpha=None) # t1935: "cuda:0 f32[1, 512, 1]" - # t1935 = prims.add(t1933, 1e-05) # t1935: "cuda:0 f32[1, 512, 1]" - t1936 = ltorch.rsqrt(t1935) # t1936: "cuda:0 f32[1, 512, 1]" - # t1936 = prims.rsqrt(t1935) # t1936: "cuda:0 f32[1, 512, 1]" - t1938 = ltorch.mul(t1928, t1936) # t1938: "cuda:0 f32[1, 512, 4096]" - # t1937 = prims.broadcast_in_dim(t1936, (1, 512, 4096), (0, 1, 2)) # t1937: "cuda:0 f32[1, 512, 4096]" - # t1938 = prims.mul(t1928, t1937) # t1938: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1939 = ltorch.to(t1938, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1939: "cuda:0 bf16[1, 512, 4096]" - # t1939 = prims.convert_element_type(t1938, dtypes.bfloat16) # t1939: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t1949 = ltorch.mul(t1939, t_transformer_h_11_norm_2_weight) # t1949: "cuda:0 bf16[1, 512, 4096]" - # t1945 = prims.broadcast_in_dim(t_transformer_h_11_norm_2_weight, (1, 512, 4096), (2,)) # t1945: "cuda:0 bf16[1, 512, 4096]" - # t1946 = prims.convert_element_type(t1939, dtypes.float32) # t1946: "cuda:0 f32[1, 512, 4096]" - # t1947 = prims.convert_element_type(t1945, dtypes.float32) # t1947: "cuda:0 f32[1, 512, 4096]" - # t1948 = prims.mul(t1946, t1947) # t1948: "cuda:0 f32[1, 512, 4096]" - # t1949 = prims.convert_element_type(t1948, dtypes.bfloat16) # t1949: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1954 = ltorch.linear(t1949, t_transformer_h_11_mlp_fc_1_weight, None) # t1954: "cuda:0 bf16[1, 512, 11008]" - # t1954 = prims.linear(t1949, t_transformer_h_11_mlp_fc_1_weight, None) # t1954: "cuda:0 bf16[1, 512, 11008]" - t1958 = ltorch.linear(t1949, t_transformer_h_11_mlp_fc_2_weight, None) # t1958: "cuda:0 bf16[1, 512, 11008]" - # t1958 = prims.linear(t1949, t_transformer_h_11_mlp_fc_2_weight, None) # t1958: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t1968 = ltorch.silu(t1954, False) # t1968: "cuda:0 bf16[1, 512, 11008]" - # t1959 = prims.convert_element_type(t1954, dtypes.float32) # t1959: "cuda:0 f32[1, 512, 11008]" - # t1960 = prims.neg(t1959) # t1960: "cuda:0 f32[1, 512, 11008]" - # t1961 = prims.exp(t1960) # t1961: "cuda:0 f32[1, 512, 11008]" - # t1962 = prims.add(1.0, t1961) # t1962: "cuda:0 f32[1, 512, 11008]" - # t1963 = prims.reciprocal(t1962) # t1963: "cuda:0 f32[1, 512, 11008]" - # t1964 = prims.convert_element_type(t1963, dtypes.bfloat16) # t1964: "cuda:0 bf16[1, 512, 11008]" - # t1965 = prims.convert_element_type(t1954, dtypes.float32) # t1965: "cuda:0 f32[1, 512, 11008]" - # t1966 = prims.convert_element_type(t1964, dtypes.float32) # t1966: "cuda:0 f32[1, 512, 11008]" - # t1967 = prims.mul(t1965, t1966) # t1967: "cuda:0 f32[1, 512, 11008]" - # t1968 = prims.convert_element_type(t1967, dtypes.bfloat16) # t1968: "cuda:0 bf16[1, 512, 11008]" - t1972 = ltorch.mul(t1968, t1958) # t1972: "cuda:0 bf16[1, 512, 11008]" - # t1969 = prims.convert_element_type(t1968, dtypes.float32) # t1969: "cuda:0 f32[1, 512, 11008]" - # t1970 = prims.convert_element_type(t1958, dtypes.float32) # t1970: "cuda:0 f32[1, 512, 11008]" - # t1971 = prims.mul(t1969, t1970) # t1971: "cuda:0 f32[1, 512, 11008]" - # t1972 = prims.convert_element_type(t1971, dtypes.bfloat16) # t1972: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t1976 = ltorch.linear(t1972, t_transformer_h_11_mlp_proj_weight, None) # t1976: "cuda:0 bf16[1, 512, 4096]" - # t1976 = prims.linear(t1972, t_transformer_h_11_mlp_proj_weight, None) # t1976: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t1980 = ltorch.add(t1976, t1927, alpha=None) # t1980: "cuda:0 bf16[1, 512, 4096]" - # t1977 = prims.convert_element_type(t1976, dtypes.float32) # t1977: "cuda:0 f32[1, 512, 4096]" - # t1978 = prims.convert_element_type(t1927, dtypes.float32) # t1978: "cuda:0 f32[1, 512, 4096]" - # t1979 = prims.add(t1977, t1978) # t1979: "cuda:0 f32[1, 512, 4096]" - # t1980 = prims.convert_element_type(t1979, dtypes.bfloat16) # t1980: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t1982 = prims.convert_element_type(t1980, dtypes.float32) # t1982: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t1983 = ltorch.mul(t1982, t1982) # t1983: "cuda:0 f32[1, 512, 4096]" - # t1983 = prims.mul(t1982, t1982) # t1983: "cuda:0 f32[1, 512, 4096]" - t1987 = ltorch.mean(t1983, -1, True, dtype=None) # t1987: "cuda:0 f32[1, 512, 1]" - # t1985 = prims.sum(t1983, (2,)) # t1985: "cuda:0 f32[1, 512]" - # t1986 = prims.broadcast_in_dim(t1985, [1, 512, 1], [0, 1]) # t1986: "cuda:0 f32[1, 512, 1]" - # t1987 = ltorch.true_divide(t1986, 4096) # t1987: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t1987 = prims.div(t1986, 4096.0) # t1987: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t1989 = ltorch.add(t1987, 1e-05, alpha=None) # t1989: "cuda:0 f32[1, 512, 1]" - # t1989 = prims.add(t1987, 1e-05) # t1989: "cuda:0 f32[1, 512, 1]" - t1990 = ltorch.rsqrt(t1989) # t1990: "cuda:0 f32[1, 512, 1]" - # t1990 = prims.rsqrt(t1989) # t1990: "cuda:0 f32[1, 512, 1]" - t1992 = ltorch.mul(t1982, t1990) # t1992: "cuda:0 f32[1, 512, 4096]" - # t1991 = prims.broadcast_in_dim(t1990, (1, 512, 4096), (0, 1, 2)) # t1991: "cuda:0 f32[1, 512, 4096]" - # t1992 = prims.mul(t1982, t1991) # t1992: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t1993 = ltorch.to(t1992, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t1993: "cuda:0 bf16[1, 512, 4096]" - # t1993 = prims.convert_element_type(t1992, dtypes.bfloat16) # t1993: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2003 = ltorch.mul(t1993, t_transformer_h_12_norm_1_weight) # t2003: "cuda:0 bf16[1, 512, 4096]" - # t1999 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_weight, (1, 512, 4096), (2,)) # t1999: "cuda:0 bf16[1, 512, 4096]" - # t2000 = prims.convert_element_type(t1993, dtypes.float32) # t2000: "cuda:0 f32[1, 512, 4096]" - # t2001 = prims.convert_element_type(t1999, dtypes.float32) # t2001: "cuda:0 f32[1, 512, 4096]" - # t2002 = prims.mul(t2000, t2001) # t2002: "cuda:0 f32[1, 512, 4096]" - # t2003 = prims.convert_element_type(t2002, dtypes.bfloat16) # t2003: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2008 = ltorch.linear(t2003, t_transformer_h_12_attn_attn_weight, None) # t2008: "cuda:0 bf16[1, 512, 12288]" - # t2008 = prims.linear(t2003, t_transformer_h_12_attn_attn_weight, None) # t2008: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2009 = ltorch.view(t2008, 1, 512, 32, 3, 128) # t2009: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2009 = ltorch.reshape(t2008, (1, 512, 32, 3, 128)) # t2009: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2009 = prims.reshape(t2008, (1, 512, 32, 3, 128)) # t2009: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2010 = ltorch.permute(t2009, 0, 2, 3, 1, 4) # t2010: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2010 = prims.transpose(t2009, (0, 2, 3, 1, 4)) # t2010: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2011, t2012, t2013) = ltorch.split(t2010, (1, 1, 1), 2) - # t2011 = prims.slice_prim(t2010, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2011: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2012 = prims.slice_prim(t2010, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2012: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2013 = prims.slice_prim(t2010, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2013: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2014 = ltorch.reshape(t2011, 1, -1, 512, 128) # t2014: "cuda:0 bf16[1, 32, 512, 128]" - # t2014 = prims.reshape(t2011, (1, 32, 512, 128)) # t2014: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2015 = ltorch.reshape(t2012, 1, -1, 512, 128) # t2015: "cuda:0 bf16[1, 32, 512, 128]" - # t2015 = prims.reshape(t2012, (1, 32, 512, 128)) # t2015: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2016 = ltorch.reshape(t2013, 1, -1, 512, 128) # t2016: "cuda:0 bf16[1, 32, 512, 128]" - # t2016 = prims.reshape(t2013, (1, 32, 512, 128)) # t2016: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2017 = ltorch.getitem(t2014, (..., slice(None, 128, None))) # t2017: "cuda:0 bf16[1, 32, 512, 128]" - # t2017 = prims.slice_prim(t2014, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2017: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2018 = ltorch.getitem(t2017, (..., slice(None, 64, None))) # t2018: "cuda:0 bf16[1, 32, 512, 64]" - # t2018 = prims.slice_prim(t2017, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2018: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2019 = ltorch.getitem(t2017, (..., slice(64, None, None))) # t2019: "cuda:0 bf16[1, 32, 512, 64]" - # t2019 = prims.slice_prim(t2017, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2019: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2022 = ltorch.neg(t2019) # t2022: "cuda:0 bf16[1, 32, 512, 64]" - # t2020 = prims.convert_element_type(t2019, dtypes.float32) # t2020: "cuda:0 f32[1, 32, 512, 64]" - # t2021 = prims.neg(t2020) # t2021: "cuda:0 f32[1, 32, 512, 64]" - # t2022 = prims.convert_element_type(t2021, dtypes.bfloat16) # t2022: "cuda:0 bf16[1, 32, 512, 64]" - t2023 = ltorch.cat((t2022, t2018), -1) # t2023: "cuda:0 bf16[1, 32, 512, 128]" - # t2023 = prims.cat((t2022, t2018), -1) # t2023: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2026 = ltorch.mul(t2017, cos) # t2026: "cuda:0 f32[1, 32, 512, 128]" - # t2024 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2024: "cuda:0 f32[1, 32, 512, 128]" - # t2025 = prims.convert_element_type(t2017, dtypes.float32) # t2025: "cuda:0 f32[1, 32, 512, 128]" - # t2026 = prims.mul(t2025, t2024) # t2026: "cuda:0 f32[1, 32, 512, 128]" - t2029 = ltorch.mul(t2023, sin) # t2029: "cuda:0 f32[1, 32, 512, 128]" - # t2027 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2027: "cuda:0 f32[1, 32, 512, 128]" - # t2028 = prims.convert_element_type(t2023, dtypes.float32) # t2028: "cuda:0 f32[1, 32, 512, 128]" - # t2029 = prims.mul(t2028, t2027) # t2029: "cuda:0 f32[1, 32, 512, 128]" - t2030 = ltorch.add(t2026, t2029, alpha=None) # t2030: "cuda:0 f32[1, 32, 512, 128]" - # t2030 = prims.add(t2026, t2029) # t2030: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2031 = ltorch.to(t2030, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2031: "cuda:0 bf16[1, 32, 512, 128]" - # t2031 = prims.convert_element_type(t2030, dtypes.bfloat16) # t2031: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2032 = ltorch.getitem(t2015, (..., slice(None, 128, None))) # t2032: "cuda:0 bf16[1, 32, 512, 128]" - # t2032 = prims.slice_prim(t2015, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2032: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2033 = ltorch.getitem(t2032, (..., slice(None, 64, None))) # t2033: "cuda:0 bf16[1, 32, 512, 64]" - # t2033 = prims.slice_prim(t2032, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2033: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2034 = ltorch.getitem(t2032, (..., slice(64, None, None))) # t2034: "cuda:0 bf16[1, 32, 512, 64]" - # t2034 = prims.slice_prim(t2032, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2034: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2037 = ltorch.neg(t2034) # t2037: "cuda:0 bf16[1, 32, 512, 64]" - # t2035 = prims.convert_element_type(t2034, dtypes.float32) # t2035: "cuda:0 f32[1, 32, 512, 64]" - # t2036 = prims.neg(t2035) # t2036: "cuda:0 f32[1, 32, 512, 64]" - # t2037 = prims.convert_element_type(t2036, dtypes.bfloat16) # t2037: "cuda:0 bf16[1, 32, 512, 64]" - t2038 = ltorch.cat((t2037, t2033), -1) # t2038: "cuda:0 bf16[1, 32, 512, 128]" - # t2038 = prims.cat((t2037, t2033), -1) # t2038: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2041 = ltorch.mul(t2032, cos) # t2041: "cuda:0 f32[1, 32, 512, 128]" - # t2039 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2039: "cuda:0 f32[1, 32, 512, 128]" - # t2040 = prims.convert_element_type(t2032, dtypes.float32) # t2040: "cuda:0 f32[1, 32, 512, 128]" - # t2041 = prims.mul(t2040, t2039) # t2041: "cuda:0 f32[1, 32, 512, 128]" - t2044 = ltorch.mul(t2038, sin) # t2044: "cuda:0 f32[1, 32, 512, 128]" - # t2042 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2042: "cuda:0 f32[1, 32, 512, 128]" - # t2043 = prims.convert_element_type(t2038, dtypes.float32) # t2043: "cuda:0 f32[1, 32, 512, 128]" - # t2044 = prims.mul(t2043, t2042) # t2044: "cuda:0 f32[1, 32, 512, 128]" - t2045 = ltorch.add(t2041, t2044, alpha=None) # t2045: "cuda:0 f32[1, 32, 512, 128]" - # t2045 = prims.add(t2041, t2044) # t2045: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2046 = ltorch.to(t2045, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2046: "cuda:0 bf16[1, 32, 512, 128]" - # t2046 = prims.convert_element_type(t2045, dtypes.bfloat16) # t2046: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2047 = ltorch.getitem(t2014, (..., slice(128, None, None))) # t2047: "cuda:0 bf16[1, 32, 512, 0]" - # t2047 = prims.slice_prim(t2014, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2047: "cuda:0 bf16[1, 32, 512, 0]" - t2048 = ltorch.cat((t2031, t2047), -1) # t2048: "cuda:0 bf16[1, 32, 512, 128]" - # t2048 = prims.cat((t2031, t2047), -1) # t2048: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2049 = ltorch.getitem(t2015, (..., slice(128, None, None))) # t2049: "cuda:0 bf16[1, 32, 512, 0]" - # t2049 = prims.slice_prim(t2015, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2049: "cuda:0 bf16[1, 32, 512, 0]" - t2050 = ltorch.cat((t2046, t2049), -1) # t2050: "cuda:0 bf16[1, 32, 512, 128]" - # t2050 = prims.cat((t2046, t2049), -1) # t2050: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2080 = ltorch.scaled_dot_product_attention(t2048, t2050, t2016, None, 0.0, True, scale=0.08838834764831843) # t2080: "cuda:0 bf16[1, 32, 512, 128]" - # t2053 = ltorch.mul(t2048, 0.29730177875068026) # t2053: "cuda:0 bf16[1, 32, 512, 128]" - # t2051 = prims.convert_element_type(t2048, dtypes.float32) # t2051: "cuda:0 f32[1, 32, 512, 128]" - # t2052 = prims.mul(t2051, 0.29730177875068026) # t2052: "cuda:0 f32[1, 32, 512, 128]" - # t2053 = prims.convert_element_type(t2052, dtypes.bfloat16) # t2053: "cuda:0 bf16[1, 32, 512, 128]" - # t2054 = ltorch.transpose(t2050, -2, -1) # t2054: "cuda:0 bf16[1, 32, 128, 512]" - # t2054 = prims.transpose(t2050, (0, 1, 3, 2)) # t2054: "cuda:0 bf16[1, 32, 128, 512]" - # t2057 = ltorch.mul(t2054, 0.29730177875068026) # t2057: "cuda:0 bf16[1, 32, 128, 512]" - # t2055 = prims.convert_element_type(t2054, dtypes.float32) # t2055: "cuda:0 f32[1, 32, 128, 512]" - # t2056 = prims.mul(t2055, 0.29730177875068026) # t2056: "cuda:0 f32[1, 32, 128, 512]" - # t2057 = prims.convert_element_type(t2056, dtypes.bfloat16) # t2057: "cuda:0 bf16[1, 32, 128, 512]" - # t2058 = ltorch.matmul(t2053, t2057) # t2058: "cuda:0 bf16[1, 32, 512, 512]" - # t2058 = prims.matmul(t2053, t2057) # t2058: "cuda:0 bf16[1, 32, 512, 512]" - # t2068 = ltorch.tril(t2058, 0, fill_value=-float('inf')) # t2068: "cuda:0 bf16[1, 32, 512, 512]" - # t2059 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2059: "cuda:0 i64[512]" - # t2059 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2059: "cuda:0 i64[512]" - # t2060 = ltorch.unsqueeze(t2059, -1) # t2060: "cuda:0 i64[512, 1]" - # t2060 = prims.broadcast_in_dim(t2059, [512, 1], [0]) # t2060: "cuda:0 i64[512, 1]" - # t2061 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2061: "cuda:0 i64[512]" - # t2061 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2061: "cuda:0 i64[512]" - # t2062 = ltorch.unsqueeze(t2061, -2) # t2062: "cuda:0 i64[1, 512]" - # t2062 = prims.broadcast_in_dim(t2061, [1, 512], [1]) # t2062: "cuda:0 i64[1, 512]" - # t2063 = ltorch.add(t2060, 0, alpha=None) # t2063: "cuda:0 i64[512, 1]" - # t2063 = prims.add(t2060, 0) # t2063: "cuda:0 i64[512, 1]" - # t2066 = ltorch.ge(t2063, t2062) # t2066: "cuda:0 b8[512, 512]" - # t2064 = prims.broadcast_in_dim(t2063, (512, 512), (0, 1)) # t2064: "cuda:0 i64[512, 512]" - # t2065 = prims.broadcast_in_dim(t2062, (512, 512), (0, 1)) # t2065: "cuda:0 i64[512, 512]" - # t2066 = prims.ge(t2064, t2065) # t2066: "cuda:0 b8[512, 512]" - # t2068 = ltorch.where(t2066, t2058, -float('inf')) # t2068: "cuda:0 bf16[1, 32, 512, 512]" - # t2067 = prims.broadcast_in_dim(t2066, (1, 32, 512, 512), (2, 3)) # t2067: "cuda:0 b8[1, 32, 512, 512]" - # t2068 = prims.where(t2067, t2058, -float('inf')) # t2068: "cuda:0 bf16[1, 32, 512, 512]" - # t2079 = ltorch._softmax(t2068, -1, dtype=None) # t2079: "cuda:0 bf16[1, 32, 512, 512]" - # t2069 = ltorch.to(t2068, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2069: "cuda:0 f32[1, 32, 512, 512]" - # t2069 = prims.convert_element_type(t2068, dtypes.float32) # t2069: "cuda:0 f32[1, 32, 512, 512]" - # t2071 = ltorch.amax(t2069, -1, True) # t2071: "cuda:0 f32[1, 32, 512, 1]" - # t2070 = prims.amax(t2069, (3,)) # t2070: "cuda:0 f32[1, 32, 512]" - # t2071 = prims.broadcast_in_dim(t2070, [1, 32, 512, 1], [0, 1, 2]) # t2071: "cuda:0 f32[1, 32, 512, 1]" - # t2073 = ltorch.sub(t2069, t2071, alpha=None) # t2073: "cuda:0 f32[1, 32, 512, 512]" - # t2072 = prims.broadcast_in_dim(t2071, (1, 32, 512, 512), (0, 1, 2, 3)) # t2072: "cuda:0 f32[1, 32, 512, 512]" - # t2073 = prims.sub(t2069, t2072) # t2073: "cuda:0 f32[1, 32, 512, 512]" - # t2074 = ltorch.exp(t2073) # t2074: "cuda:0 f32[1, 32, 512, 512]" - # t2074 = prims.exp(t2073) # t2074: "cuda:0 f32[1, 32, 512, 512]" - # t2076 = ltorch.sum(t2074, -1, True, dtype=None) # t2076: "cuda:0 f32[1, 32, 512, 1]" - # t2075 = prims.sum(t2074, (3,)) # t2075: "cuda:0 f32[1, 32, 512]" - # t2076 = prims.broadcast_in_dim(t2075, [1, 32, 512, 1], [0, 1, 2]) # t2076: "cuda:0 f32[1, 32, 512, 1]" - # t2078 = ltorch.true_divide(t2074, t2076) # t2078: "cuda:0 f32[1, 32, 512, 512]" - # t2077 = prims.broadcast_in_dim(t2076, (1, 32, 512, 512), (0, 1, 2, 3)) # t2077: "cuda:0 f32[1, 32, 512, 512]" - # t2078 = prims.div(t2074, t2077) # t2078: "cuda:0 f32[1, 32, 512, 512]" - # t2079 = ltorch.to(t2078, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2079: "cuda:0 bf16[1, 32, 512, 512]" - # t2079 = prims.convert_element_type(t2078, dtypes.bfloat16) # t2079: "cuda:0 bf16[1, 32, 512, 512]" - # t2080 = ltorch.matmul(t2079, t2016) # t2080: "cuda:0 bf16[1, 32, 512, 128]" - # t2080 = prims.matmul(t2079, t2016) # t2080: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2081 = ltorch.transpose(t2080, 1, 2) # t2081: "cuda:0 bf16[1, 512, 32, 128]" - # t2081 = prims.transpose(t2080, (0, 2, 1, 3)) # t2081: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2082 = ltorch.reshape(t2081, 1, 512, 4096) # t2082: "cuda:0 bf16[1, 512, 4096]" - # t2082 = prims.reshape(t2081, (1, 512, 4096)) # t2082: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2086 = ltorch.linear(t2082, t_transformer_h_12_attn_proj_weight, None) # t2086: "cuda:0 bf16[1, 512, 4096]" - # t2086 = prims.linear(t2082, t_transformer_h_12_attn_proj_weight, None) # t2086: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2090 = ltorch.add(t2086, t1980, alpha=None) # t2090: "cuda:0 bf16[1, 512, 4096]" - # t2087 = prims.convert_element_type(t2086, dtypes.float32) # t2087: "cuda:0 f32[1, 512, 4096]" - # t2088 = prims.convert_element_type(t1980, dtypes.float32) # t2088: "cuda:0 f32[1, 512, 4096]" - # t2089 = prims.add(t2087, t2088) # t2089: "cuda:0 f32[1, 512, 4096]" - # t2090 = prims.convert_element_type(t2089, dtypes.bfloat16) # t2090: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2091 = prims.convert_element_type(t2090, dtypes.float32) # t2091: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2092 = ltorch.mul(t2091, t2091) # t2092: "cuda:0 f32[1, 512, 4096]" - # t2092 = prims.mul(t2091, t2091) # t2092: "cuda:0 f32[1, 512, 4096]" - t2096 = ltorch.mean(t2092, -1, True, dtype=None) # t2096: "cuda:0 f32[1, 512, 1]" - # t2094 = prims.sum(t2092, (2,)) # t2094: "cuda:0 f32[1, 512]" - # t2095 = prims.broadcast_in_dim(t2094, [1, 512, 1], [0, 1]) # t2095: "cuda:0 f32[1, 512, 1]" - # t2096 = ltorch.true_divide(t2095, 4096) # t2096: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2096 = prims.div(t2095, 4096.0) # t2096: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2098 = ltorch.add(t2096, 1e-05, alpha=None) # t2098: "cuda:0 f32[1, 512, 1]" - # t2098 = prims.add(t2096, 1e-05) # t2098: "cuda:0 f32[1, 512, 1]" - t2099 = ltorch.rsqrt(t2098) # t2099: "cuda:0 f32[1, 512, 1]" - # t2099 = prims.rsqrt(t2098) # t2099: "cuda:0 f32[1, 512, 1]" - t2101 = ltorch.mul(t2091, t2099) # t2101: "cuda:0 f32[1, 512, 4096]" - # t2100 = prims.broadcast_in_dim(t2099, (1, 512, 4096), (0, 1, 2)) # t2100: "cuda:0 f32[1, 512, 4096]" - # t2101 = prims.mul(t2091, t2100) # t2101: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2102 = ltorch.to(t2101, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2102: "cuda:0 bf16[1, 512, 4096]" - # t2102 = prims.convert_element_type(t2101, dtypes.bfloat16) # t2102: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2112 = ltorch.mul(t2102, t_transformer_h_12_norm_2_weight) # t2112: "cuda:0 bf16[1, 512, 4096]" - # t2108 = prims.broadcast_in_dim(t_transformer_h_12_norm_2_weight, (1, 512, 4096), (2,)) # t2108: "cuda:0 bf16[1, 512, 4096]" - # t2109 = prims.convert_element_type(t2102, dtypes.float32) # t2109: "cuda:0 f32[1, 512, 4096]" - # t2110 = prims.convert_element_type(t2108, dtypes.float32) # t2110: "cuda:0 f32[1, 512, 4096]" - # t2111 = prims.mul(t2109, t2110) # t2111: "cuda:0 f32[1, 512, 4096]" - # t2112 = prims.convert_element_type(t2111, dtypes.bfloat16) # t2112: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2117 = ltorch.linear(t2112, t_transformer_h_12_mlp_fc_1_weight, None) # t2117: "cuda:0 bf16[1, 512, 11008]" - # t2117 = prims.linear(t2112, t_transformer_h_12_mlp_fc_1_weight, None) # t2117: "cuda:0 bf16[1, 512, 11008]" - t2121 = ltorch.linear(t2112, t_transformer_h_12_mlp_fc_2_weight, None) # t2121: "cuda:0 bf16[1, 512, 11008]" - # t2121 = prims.linear(t2112, t_transformer_h_12_mlp_fc_2_weight, None) # t2121: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2131 = ltorch.silu(t2117, False) # t2131: "cuda:0 bf16[1, 512, 11008]" - # t2122 = prims.convert_element_type(t2117, dtypes.float32) # t2122: "cuda:0 f32[1, 512, 11008]" - # t2123 = prims.neg(t2122) # t2123: "cuda:0 f32[1, 512, 11008]" - # t2124 = prims.exp(t2123) # t2124: "cuda:0 f32[1, 512, 11008]" - # t2125 = prims.add(1.0, t2124) # t2125: "cuda:0 f32[1, 512, 11008]" - # t2126 = prims.reciprocal(t2125) # t2126: "cuda:0 f32[1, 512, 11008]" - # t2127 = prims.convert_element_type(t2126, dtypes.bfloat16) # t2127: "cuda:0 bf16[1, 512, 11008]" - # t2128 = prims.convert_element_type(t2117, dtypes.float32) # t2128: "cuda:0 f32[1, 512, 11008]" - # t2129 = prims.convert_element_type(t2127, dtypes.float32) # t2129: "cuda:0 f32[1, 512, 11008]" - # t2130 = prims.mul(t2128, t2129) # t2130: "cuda:0 f32[1, 512, 11008]" - # t2131 = prims.convert_element_type(t2130, dtypes.bfloat16) # t2131: "cuda:0 bf16[1, 512, 11008]" - t2135 = ltorch.mul(t2131, t2121) # t2135: "cuda:0 bf16[1, 512, 11008]" - # t2132 = prims.convert_element_type(t2131, dtypes.float32) # t2132: "cuda:0 f32[1, 512, 11008]" - # t2133 = prims.convert_element_type(t2121, dtypes.float32) # t2133: "cuda:0 f32[1, 512, 11008]" - # t2134 = prims.mul(t2132, t2133) # t2134: "cuda:0 f32[1, 512, 11008]" - # t2135 = prims.convert_element_type(t2134, dtypes.bfloat16) # t2135: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2139 = ltorch.linear(t2135, t_transformer_h_12_mlp_proj_weight, None) # t2139: "cuda:0 bf16[1, 512, 4096]" - # t2139 = prims.linear(t2135, t_transformer_h_12_mlp_proj_weight, None) # t2139: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2143 = ltorch.add(t2139, t2090, alpha=None) # t2143: "cuda:0 bf16[1, 512, 4096]" - # t2140 = prims.convert_element_type(t2139, dtypes.float32) # t2140: "cuda:0 f32[1, 512, 4096]" - # t2141 = prims.convert_element_type(t2090, dtypes.float32) # t2141: "cuda:0 f32[1, 512, 4096]" - # t2142 = prims.add(t2140, t2141) # t2142: "cuda:0 f32[1, 512, 4096]" - # t2143 = prims.convert_element_type(t2142, dtypes.bfloat16) # t2143: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2145 = prims.convert_element_type(t2143, dtypes.float32) # t2145: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2146 = ltorch.mul(t2145, t2145) # t2146: "cuda:0 f32[1, 512, 4096]" - # t2146 = prims.mul(t2145, t2145) # t2146: "cuda:0 f32[1, 512, 4096]" - t2150 = ltorch.mean(t2146, -1, True, dtype=None) # t2150: "cuda:0 f32[1, 512, 1]" - # t2148 = prims.sum(t2146, (2,)) # t2148: "cuda:0 f32[1, 512]" - # t2149 = prims.broadcast_in_dim(t2148, [1, 512, 1], [0, 1]) # t2149: "cuda:0 f32[1, 512, 1]" - # t2150 = ltorch.true_divide(t2149, 4096) # t2150: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2150 = prims.div(t2149, 4096.0) # t2150: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2152 = ltorch.add(t2150, 1e-05, alpha=None) # t2152: "cuda:0 f32[1, 512, 1]" - # t2152 = prims.add(t2150, 1e-05) # t2152: "cuda:0 f32[1, 512, 1]" - t2153 = ltorch.rsqrt(t2152) # t2153: "cuda:0 f32[1, 512, 1]" - # t2153 = prims.rsqrt(t2152) # t2153: "cuda:0 f32[1, 512, 1]" - t2155 = ltorch.mul(t2145, t2153) # t2155: "cuda:0 f32[1, 512, 4096]" - # t2154 = prims.broadcast_in_dim(t2153, (1, 512, 4096), (0, 1, 2)) # t2154: "cuda:0 f32[1, 512, 4096]" - # t2155 = prims.mul(t2145, t2154) # t2155: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2156 = ltorch.to(t2155, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2156: "cuda:0 bf16[1, 512, 4096]" - # t2156 = prims.convert_element_type(t2155, dtypes.bfloat16) # t2156: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2166 = ltorch.mul(t2156, t_transformer_h_13_norm_1_weight) # t2166: "cuda:0 bf16[1, 512, 4096]" - # t2162 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_weight, (1, 512, 4096), (2,)) # t2162: "cuda:0 bf16[1, 512, 4096]" - # t2163 = prims.convert_element_type(t2156, dtypes.float32) # t2163: "cuda:0 f32[1, 512, 4096]" - # t2164 = prims.convert_element_type(t2162, dtypes.float32) # t2164: "cuda:0 f32[1, 512, 4096]" - # t2165 = prims.mul(t2163, t2164) # t2165: "cuda:0 f32[1, 512, 4096]" - # t2166 = prims.convert_element_type(t2165, dtypes.bfloat16) # t2166: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2171 = ltorch.linear(t2166, t_transformer_h_13_attn_attn_weight, None) # t2171: "cuda:0 bf16[1, 512, 12288]" - # t2171 = prims.linear(t2166, t_transformer_h_13_attn_attn_weight, None) # t2171: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2172 = ltorch.view(t2171, 1, 512, 32, 3, 128) # t2172: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2172 = ltorch.reshape(t2171, (1, 512, 32, 3, 128)) # t2172: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2172 = prims.reshape(t2171, (1, 512, 32, 3, 128)) # t2172: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2173 = ltorch.permute(t2172, 0, 2, 3, 1, 4) # t2173: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2173 = prims.transpose(t2172, (0, 2, 3, 1, 4)) # t2173: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2174, t2175, t2176) = ltorch.split(t2173, (1, 1, 1), 2) - # t2174 = prims.slice_prim(t2173, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2174: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2175 = prims.slice_prim(t2173, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2175: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2176 = prims.slice_prim(t2173, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2176: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2177 = ltorch.reshape(t2174, 1, -1, 512, 128) # t2177: "cuda:0 bf16[1, 32, 512, 128]" - # t2177 = prims.reshape(t2174, (1, 32, 512, 128)) # t2177: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2178 = ltorch.reshape(t2175, 1, -1, 512, 128) # t2178: "cuda:0 bf16[1, 32, 512, 128]" - # t2178 = prims.reshape(t2175, (1, 32, 512, 128)) # t2178: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2179 = ltorch.reshape(t2176, 1, -1, 512, 128) # t2179: "cuda:0 bf16[1, 32, 512, 128]" - # t2179 = prims.reshape(t2176, (1, 32, 512, 128)) # t2179: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2180 = ltorch.getitem(t2177, (..., slice(None, 128, None))) # t2180: "cuda:0 bf16[1, 32, 512, 128]" - # t2180 = prims.slice_prim(t2177, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2180: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2181 = ltorch.getitem(t2180, (..., slice(None, 64, None))) # t2181: "cuda:0 bf16[1, 32, 512, 64]" - # t2181 = prims.slice_prim(t2180, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2181: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2182 = ltorch.getitem(t2180, (..., slice(64, None, None))) # t2182: "cuda:0 bf16[1, 32, 512, 64]" - # t2182 = prims.slice_prim(t2180, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2182: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2185 = ltorch.neg(t2182) # t2185: "cuda:0 bf16[1, 32, 512, 64]" - # t2183 = prims.convert_element_type(t2182, dtypes.float32) # t2183: "cuda:0 f32[1, 32, 512, 64]" - # t2184 = prims.neg(t2183) # t2184: "cuda:0 f32[1, 32, 512, 64]" - # t2185 = prims.convert_element_type(t2184, dtypes.bfloat16) # t2185: "cuda:0 bf16[1, 32, 512, 64]" - t2186 = ltorch.cat((t2185, t2181), -1) # t2186: "cuda:0 bf16[1, 32, 512, 128]" - # t2186 = prims.cat((t2185, t2181), -1) # t2186: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2189 = ltorch.mul(t2180, cos) # t2189: "cuda:0 f32[1, 32, 512, 128]" - # t2187 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2187: "cuda:0 f32[1, 32, 512, 128]" - # t2188 = prims.convert_element_type(t2180, dtypes.float32) # t2188: "cuda:0 f32[1, 32, 512, 128]" - # t2189 = prims.mul(t2188, t2187) # t2189: "cuda:0 f32[1, 32, 512, 128]" - t2192 = ltorch.mul(t2186, sin) # t2192: "cuda:0 f32[1, 32, 512, 128]" - # t2190 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2190: "cuda:0 f32[1, 32, 512, 128]" - # t2191 = prims.convert_element_type(t2186, dtypes.float32) # t2191: "cuda:0 f32[1, 32, 512, 128]" - # t2192 = prims.mul(t2191, t2190) # t2192: "cuda:0 f32[1, 32, 512, 128]" - t2193 = ltorch.add(t2189, t2192, alpha=None) # t2193: "cuda:0 f32[1, 32, 512, 128]" - # t2193 = prims.add(t2189, t2192) # t2193: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2194 = ltorch.to(t2193, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - # t2194 = prims.convert_element_type(t2193, dtypes.bfloat16) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2195 = ltorch.getitem(t2178, (..., slice(None, 128, None))) # t2195: "cuda:0 bf16[1, 32, 512, 128]" - # t2195 = prims.slice_prim(t2178, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2195: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2196 = ltorch.getitem(t2195, (..., slice(None, 64, None))) # t2196: "cuda:0 bf16[1, 32, 512, 64]" - # t2196 = prims.slice_prim(t2195, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2196: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2197 = ltorch.getitem(t2195, (..., slice(64, None, None))) # t2197: "cuda:0 bf16[1, 32, 512, 64]" - # t2197 = prims.slice_prim(t2195, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2197: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2200 = ltorch.neg(t2197) # t2200: "cuda:0 bf16[1, 32, 512, 64]" - # t2198 = prims.convert_element_type(t2197, dtypes.float32) # t2198: "cuda:0 f32[1, 32, 512, 64]" - # t2199 = prims.neg(t2198) # t2199: "cuda:0 f32[1, 32, 512, 64]" - # t2200 = prims.convert_element_type(t2199, dtypes.bfloat16) # t2200: "cuda:0 bf16[1, 32, 512, 64]" - t2201 = ltorch.cat((t2200, t2196), -1) # t2201: "cuda:0 bf16[1, 32, 512, 128]" - # t2201 = prims.cat((t2200, t2196), -1) # t2201: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2204 = ltorch.mul(t2195, cos) # t2204: "cuda:0 f32[1, 32, 512, 128]" - # t2202 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2202: "cuda:0 f32[1, 32, 512, 128]" - # t2203 = prims.convert_element_type(t2195, dtypes.float32) # t2203: "cuda:0 f32[1, 32, 512, 128]" - # t2204 = prims.mul(t2203, t2202) # t2204: "cuda:0 f32[1, 32, 512, 128]" - t2207 = ltorch.mul(t2201, sin) # t2207: "cuda:0 f32[1, 32, 512, 128]" - # t2205 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2205: "cuda:0 f32[1, 32, 512, 128]" - # t2206 = prims.convert_element_type(t2201, dtypes.float32) # t2206: "cuda:0 f32[1, 32, 512, 128]" - # t2207 = prims.mul(t2206, t2205) # t2207: "cuda:0 f32[1, 32, 512, 128]" - t2208 = ltorch.add(t2204, t2207, alpha=None) # t2208: "cuda:0 f32[1, 32, 512, 128]" - # t2208 = prims.add(t2204, t2207) # t2208: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2209 = ltorch.to(t2208, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2209: "cuda:0 bf16[1, 32, 512, 128]" - # t2209 = prims.convert_element_type(t2208, dtypes.bfloat16) # t2209: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2210 = ltorch.getitem(t2177, (..., slice(128, None, None))) # t2210: "cuda:0 bf16[1, 32, 512, 0]" - # t2210 = prims.slice_prim(t2177, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2210: "cuda:0 bf16[1, 32, 512, 0]" - t2211 = ltorch.cat((t2194, t2210), -1) # t2211: "cuda:0 bf16[1, 32, 512, 128]" - # t2211 = prims.cat((t2194, t2210), -1) # t2211: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2212 = ltorch.getitem(t2178, (..., slice(128, None, None))) # t2212: "cuda:0 bf16[1, 32, 512, 0]" - # t2212 = prims.slice_prim(t2178, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2212: "cuda:0 bf16[1, 32, 512, 0]" - t2213 = ltorch.cat((t2209, t2212), -1) # t2213: "cuda:0 bf16[1, 32, 512, 128]" - # t2213 = prims.cat((t2209, t2212), -1) # t2213: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2243 = ltorch.scaled_dot_product_attention(t2211, t2213, t2179, None, 0.0, True, scale=0.08838834764831843) # t2243: "cuda:0 bf16[1, 32, 512, 128]" - # t2216 = ltorch.mul(t2211, 0.29730177875068026) # t2216: "cuda:0 bf16[1, 32, 512, 128]" - # t2214 = prims.convert_element_type(t2211, dtypes.float32) # t2214: "cuda:0 f32[1, 32, 512, 128]" - # t2215 = prims.mul(t2214, 0.29730177875068026) # t2215: "cuda:0 f32[1, 32, 512, 128]" - # t2216 = prims.convert_element_type(t2215, dtypes.bfloat16) # t2216: "cuda:0 bf16[1, 32, 512, 128]" - # t2217 = ltorch.transpose(t2213, -2, -1) # t2217: "cuda:0 bf16[1, 32, 128, 512]" - # t2217 = prims.transpose(t2213, (0, 1, 3, 2)) # t2217: "cuda:0 bf16[1, 32, 128, 512]" - # t2220 = ltorch.mul(t2217, 0.29730177875068026) # t2220: "cuda:0 bf16[1, 32, 128, 512]" - # t2218 = prims.convert_element_type(t2217, dtypes.float32) # t2218: "cuda:0 f32[1, 32, 128, 512]" - # t2219 = prims.mul(t2218, 0.29730177875068026) # t2219: "cuda:0 f32[1, 32, 128, 512]" - # t2220 = prims.convert_element_type(t2219, dtypes.bfloat16) # t2220: "cuda:0 bf16[1, 32, 128, 512]" - # t2221 = ltorch.matmul(t2216, t2220) # t2221: "cuda:0 bf16[1, 32, 512, 512]" - # t2221 = prims.matmul(t2216, t2220) # t2221: "cuda:0 bf16[1, 32, 512, 512]" - # t2231 = ltorch.tril(t2221, 0, fill_value=-float('inf')) # t2231: "cuda:0 bf16[1, 32, 512, 512]" - # t2222 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2222: "cuda:0 i64[512]" - # t2222 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2222: "cuda:0 i64[512]" - # t2223 = ltorch.unsqueeze(t2222, -1) # t2223: "cuda:0 i64[512, 1]" - # t2223 = prims.broadcast_in_dim(t2222, [512, 1], [0]) # t2223: "cuda:0 i64[512, 1]" - # t2224 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2224: "cuda:0 i64[512]" - # t2224 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2224: "cuda:0 i64[512]" - # t2225 = ltorch.unsqueeze(t2224, -2) # t2225: "cuda:0 i64[1, 512]" - # t2225 = prims.broadcast_in_dim(t2224, [1, 512], [1]) # t2225: "cuda:0 i64[1, 512]" - # t2226 = ltorch.add(t2223, 0, alpha=None) # t2226: "cuda:0 i64[512, 1]" - # t2226 = prims.add(t2223, 0) # t2226: "cuda:0 i64[512, 1]" - # t2229 = ltorch.ge(t2226, t2225) # t2229: "cuda:0 b8[512, 512]" - # t2227 = prims.broadcast_in_dim(t2226, (512, 512), (0, 1)) # t2227: "cuda:0 i64[512, 512]" - # t2228 = prims.broadcast_in_dim(t2225, (512, 512), (0, 1)) # t2228: "cuda:0 i64[512, 512]" - # t2229 = prims.ge(t2227, t2228) # t2229: "cuda:0 b8[512, 512]" - # t2231 = ltorch.where(t2229, t2221, -float('inf')) # t2231: "cuda:0 bf16[1, 32, 512, 512]" - # t2230 = prims.broadcast_in_dim(t2229, (1, 32, 512, 512), (2, 3)) # t2230: "cuda:0 b8[1, 32, 512, 512]" - # t2231 = prims.where(t2230, t2221, -float('inf')) # t2231: "cuda:0 bf16[1, 32, 512, 512]" - # t2242 = ltorch._softmax(t2231, -1, dtype=None) # t2242: "cuda:0 bf16[1, 32, 512, 512]" - # t2232 = ltorch.to(t2231, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2232: "cuda:0 f32[1, 32, 512, 512]" - # t2232 = prims.convert_element_type(t2231, dtypes.float32) # t2232: "cuda:0 f32[1, 32, 512, 512]" - # t2234 = ltorch.amax(t2232, -1, True) # t2234: "cuda:0 f32[1, 32, 512, 1]" - # t2233 = prims.amax(t2232, (3,)) # t2233: "cuda:0 f32[1, 32, 512]" - # t2234 = prims.broadcast_in_dim(t2233, [1, 32, 512, 1], [0, 1, 2]) # t2234: "cuda:0 f32[1, 32, 512, 1]" - # t2236 = ltorch.sub(t2232, t2234, alpha=None) # t2236: "cuda:0 f32[1, 32, 512, 512]" - # t2235 = prims.broadcast_in_dim(t2234, (1, 32, 512, 512), (0, 1, 2, 3)) # t2235: "cuda:0 f32[1, 32, 512, 512]" - # t2236 = prims.sub(t2232, t2235) # t2236: "cuda:0 f32[1, 32, 512, 512]" - # t2237 = ltorch.exp(t2236) # t2237: "cuda:0 f32[1, 32, 512, 512]" - # t2237 = prims.exp(t2236) # t2237: "cuda:0 f32[1, 32, 512, 512]" - # t2239 = ltorch.sum(t2237, -1, True, dtype=None) # t2239: "cuda:0 f32[1, 32, 512, 1]" - # t2238 = prims.sum(t2237, (3,)) # t2238: "cuda:0 f32[1, 32, 512]" - # t2239 = prims.broadcast_in_dim(t2238, [1, 32, 512, 1], [0, 1, 2]) # t2239: "cuda:0 f32[1, 32, 512, 1]" - # t2241 = ltorch.true_divide(t2237, t2239) # t2241: "cuda:0 f32[1, 32, 512, 512]" - # t2240 = prims.broadcast_in_dim(t2239, (1, 32, 512, 512), (0, 1, 2, 3)) # t2240: "cuda:0 f32[1, 32, 512, 512]" - # t2241 = prims.div(t2237, t2240) # t2241: "cuda:0 f32[1, 32, 512, 512]" - # t2242 = ltorch.to(t2241, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2242: "cuda:0 bf16[1, 32, 512, 512]" - # t2242 = prims.convert_element_type(t2241, dtypes.bfloat16) # t2242: "cuda:0 bf16[1, 32, 512, 512]" - # t2243 = ltorch.matmul(t2242, t2179) # t2243: "cuda:0 bf16[1, 32, 512, 128]" - # t2243 = prims.matmul(t2242, t2179) # t2243: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2244 = ltorch.transpose(t2243, 1, 2) # t2244: "cuda:0 bf16[1, 512, 32, 128]" - # t2244 = prims.transpose(t2243, (0, 2, 1, 3)) # t2244: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2245 = ltorch.reshape(t2244, 1, 512, 4096) # t2245: "cuda:0 bf16[1, 512, 4096]" - # t2245 = prims.reshape(t2244, (1, 512, 4096)) # t2245: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2249 = ltorch.linear(t2245, t_transformer_h_13_attn_proj_weight, None) # t2249: "cuda:0 bf16[1, 512, 4096]" - # t2249 = prims.linear(t2245, t_transformer_h_13_attn_proj_weight, None) # t2249: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2253 = ltorch.add(t2249, t2143, alpha=None) # t2253: "cuda:0 bf16[1, 512, 4096]" - # t2250 = prims.convert_element_type(t2249, dtypes.float32) # t2250: "cuda:0 f32[1, 512, 4096]" - # t2251 = prims.convert_element_type(t2143, dtypes.float32) # t2251: "cuda:0 f32[1, 512, 4096]" - # t2252 = prims.add(t2250, t2251) # t2252: "cuda:0 f32[1, 512, 4096]" - # t2253 = prims.convert_element_type(t2252, dtypes.bfloat16) # t2253: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2254 = prims.convert_element_type(t2253, dtypes.float32) # t2254: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2255 = ltorch.mul(t2254, t2254) # t2255: "cuda:0 f32[1, 512, 4096]" - # t2255 = prims.mul(t2254, t2254) # t2255: "cuda:0 f32[1, 512, 4096]" - t2259 = ltorch.mean(t2255, -1, True, dtype=None) # t2259: "cuda:0 f32[1, 512, 1]" - # t2257 = prims.sum(t2255, (2,)) # t2257: "cuda:0 f32[1, 512]" - # t2258 = prims.broadcast_in_dim(t2257, [1, 512, 1], [0, 1]) # t2258: "cuda:0 f32[1, 512, 1]" - # t2259 = ltorch.true_divide(t2258, 4096) # t2259: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2259 = prims.div(t2258, 4096.0) # t2259: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2261 = ltorch.add(t2259, 1e-05, alpha=None) # t2261: "cuda:0 f32[1, 512, 1]" - # t2261 = prims.add(t2259, 1e-05) # t2261: "cuda:0 f32[1, 512, 1]" - t2262 = ltorch.rsqrt(t2261) # t2262: "cuda:0 f32[1, 512, 1]" - # t2262 = prims.rsqrt(t2261) # t2262: "cuda:0 f32[1, 512, 1]" - t2264 = ltorch.mul(t2254, t2262) # t2264: "cuda:0 f32[1, 512, 4096]" - # t2263 = prims.broadcast_in_dim(t2262, (1, 512, 4096), (0, 1, 2)) # t2263: "cuda:0 f32[1, 512, 4096]" - # t2264 = prims.mul(t2254, t2263) # t2264: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2265 = ltorch.to(t2264, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2265: "cuda:0 bf16[1, 512, 4096]" - # t2265 = prims.convert_element_type(t2264, dtypes.bfloat16) # t2265: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2275 = ltorch.mul(t2265, t_transformer_h_13_norm_2_weight) # t2275: "cuda:0 bf16[1, 512, 4096]" - # t2271 = prims.broadcast_in_dim(t_transformer_h_13_norm_2_weight, (1, 512, 4096), (2,)) # t2271: "cuda:0 bf16[1, 512, 4096]" - # t2272 = prims.convert_element_type(t2265, dtypes.float32) # t2272: "cuda:0 f32[1, 512, 4096]" - # t2273 = prims.convert_element_type(t2271, dtypes.float32) # t2273: "cuda:0 f32[1, 512, 4096]" - # t2274 = prims.mul(t2272, t2273) # t2274: "cuda:0 f32[1, 512, 4096]" - # t2275 = prims.convert_element_type(t2274, dtypes.bfloat16) # t2275: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2280 = ltorch.linear(t2275, t_transformer_h_13_mlp_fc_1_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - # t2280 = prims.linear(t2275, t_transformer_h_13_mlp_fc_1_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - t2284 = ltorch.linear(t2275, t_transformer_h_13_mlp_fc_2_weight, None) # t2284: "cuda:0 bf16[1, 512, 11008]" - # t2284 = prims.linear(t2275, t_transformer_h_13_mlp_fc_2_weight, None) # t2284: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2294 = ltorch.silu(t2280, False) # t2294: "cuda:0 bf16[1, 512, 11008]" - # t2285 = prims.convert_element_type(t2280, dtypes.float32) # t2285: "cuda:0 f32[1, 512, 11008]" - # t2286 = prims.neg(t2285) # t2286: "cuda:0 f32[1, 512, 11008]" - # t2287 = prims.exp(t2286) # t2287: "cuda:0 f32[1, 512, 11008]" - # t2288 = prims.add(1.0, t2287) # t2288: "cuda:0 f32[1, 512, 11008]" - # t2289 = prims.reciprocal(t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - # t2290 = prims.convert_element_type(t2289, dtypes.bfloat16) # t2290: "cuda:0 bf16[1, 512, 11008]" - # t2291 = prims.convert_element_type(t2280, dtypes.float32) # t2291: "cuda:0 f32[1, 512, 11008]" - # t2292 = prims.convert_element_type(t2290, dtypes.float32) # t2292: "cuda:0 f32[1, 512, 11008]" - # t2293 = prims.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - # t2294 = prims.convert_element_type(t2293, dtypes.bfloat16) # t2294: "cuda:0 bf16[1, 512, 11008]" - t2298 = ltorch.mul(t2294, t2284) # t2298: "cuda:0 bf16[1, 512, 11008]" - # t2295 = prims.convert_element_type(t2294, dtypes.float32) # t2295: "cuda:0 f32[1, 512, 11008]" - # t2296 = prims.convert_element_type(t2284, dtypes.float32) # t2296: "cuda:0 f32[1, 512, 11008]" - # t2297 = prims.mul(t2295, t2296) # t2297: "cuda:0 f32[1, 512, 11008]" - # t2298 = prims.convert_element_type(t2297, dtypes.bfloat16) # t2298: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2302 = ltorch.linear(t2298, t_transformer_h_13_mlp_proj_weight, None) # t2302: "cuda:0 bf16[1, 512, 4096]" - # t2302 = prims.linear(t2298, t_transformer_h_13_mlp_proj_weight, None) # t2302: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2306 = ltorch.add(t2302, t2253, alpha=None) # t2306: "cuda:0 bf16[1, 512, 4096]" - # t2303 = prims.convert_element_type(t2302, dtypes.float32) # t2303: "cuda:0 f32[1, 512, 4096]" - # t2304 = prims.convert_element_type(t2253, dtypes.float32) # t2304: "cuda:0 f32[1, 512, 4096]" - # t2305 = prims.add(t2303, t2304) # t2305: "cuda:0 f32[1, 512, 4096]" - # t2306 = prims.convert_element_type(t2305, dtypes.bfloat16) # t2306: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2308 = prims.convert_element_type(t2306, dtypes.float32) # t2308: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2309 = ltorch.mul(t2308, t2308) # t2309: "cuda:0 f32[1, 512, 4096]" - # t2309 = prims.mul(t2308, t2308) # t2309: "cuda:0 f32[1, 512, 4096]" - t2313 = ltorch.mean(t2309, -1, True, dtype=None) # t2313: "cuda:0 f32[1, 512, 1]" - # t2311 = prims.sum(t2309, (2,)) # t2311: "cuda:0 f32[1, 512]" - # t2312 = prims.broadcast_in_dim(t2311, [1, 512, 1], [0, 1]) # t2312: "cuda:0 f32[1, 512, 1]" - # t2313 = ltorch.true_divide(t2312, 4096) # t2313: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2313 = prims.div(t2312, 4096.0) # t2313: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2315 = ltorch.add(t2313, 1e-05, alpha=None) # t2315: "cuda:0 f32[1, 512, 1]" - # t2315 = prims.add(t2313, 1e-05) # t2315: "cuda:0 f32[1, 512, 1]" - t2316 = ltorch.rsqrt(t2315) # t2316: "cuda:0 f32[1, 512, 1]" - # t2316 = prims.rsqrt(t2315) # t2316: "cuda:0 f32[1, 512, 1]" - t2318 = ltorch.mul(t2308, t2316) # t2318: "cuda:0 f32[1, 512, 4096]" - # t2317 = prims.broadcast_in_dim(t2316, (1, 512, 4096), (0, 1, 2)) # t2317: "cuda:0 f32[1, 512, 4096]" - # t2318 = prims.mul(t2308, t2317) # t2318: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2319 = ltorch.to(t2318, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2319: "cuda:0 bf16[1, 512, 4096]" - # t2319 = prims.convert_element_type(t2318, dtypes.bfloat16) # t2319: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2329 = ltorch.mul(t2319, t_transformer_h_14_norm_1_weight) # t2329: "cuda:0 bf16[1, 512, 4096]" - # t2325 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_weight, (1, 512, 4096), (2,)) # t2325: "cuda:0 bf16[1, 512, 4096]" - # t2326 = prims.convert_element_type(t2319, dtypes.float32) # t2326: "cuda:0 f32[1, 512, 4096]" - # t2327 = prims.convert_element_type(t2325, dtypes.float32) # t2327: "cuda:0 f32[1, 512, 4096]" - # t2328 = prims.mul(t2326, t2327) # t2328: "cuda:0 f32[1, 512, 4096]" - # t2329 = prims.convert_element_type(t2328, dtypes.bfloat16) # t2329: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2334 = ltorch.linear(t2329, t_transformer_h_14_attn_attn_weight, None) # t2334: "cuda:0 bf16[1, 512, 12288]" - # t2334 = prims.linear(t2329, t_transformer_h_14_attn_attn_weight, None) # t2334: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2335 = ltorch.view(t2334, 1, 512, 32, 3, 128) # t2335: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2335 = ltorch.reshape(t2334, (1, 512, 32, 3, 128)) # t2335: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2335 = prims.reshape(t2334, (1, 512, 32, 3, 128)) # t2335: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2336 = ltorch.permute(t2335, 0, 2, 3, 1, 4) # t2336: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2336 = prims.transpose(t2335, (0, 2, 3, 1, 4)) # t2336: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2337, t2338, t2339) = ltorch.split(t2336, (1, 1, 1), 2) - # t2337 = prims.slice_prim(t2336, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2337: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2338 = prims.slice_prim(t2336, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2338: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2339 = prims.slice_prim(t2336, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2339: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2340 = ltorch.reshape(t2337, 1, -1, 512, 128) # t2340: "cuda:0 bf16[1, 32, 512, 128]" - # t2340 = prims.reshape(t2337, (1, 32, 512, 128)) # t2340: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2341 = ltorch.reshape(t2338, 1, -1, 512, 128) # t2341: "cuda:0 bf16[1, 32, 512, 128]" - # t2341 = prims.reshape(t2338, (1, 32, 512, 128)) # t2341: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2342 = ltorch.reshape(t2339, 1, -1, 512, 128) # t2342: "cuda:0 bf16[1, 32, 512, 128]" - # t2342 = prims.reshape(t2339, (1, 32, 512, 128)) # t2342: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2343 = ltorch.getitem(t2340, (..., slice(None, 128, None))) # t2343: "cuda:0 bf16[1, 32, 512, 128]" - # t2343 = prims.slice_prim(t2340, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2343: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2344 = ltorch.getitem(t2343, (..., slice(None, 64, None))) # t2344: "cuda:0 bf16[1, 32, 512, 64]" - # t2344 = prims.slice_prim(t2343, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2344: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2345 = ltorch.getitem(t2343, (..., slice(64, None, None))) # t2345: "cuda:0 bf16[1, 32, 512, 64]" - # t2345 = prims.slice_prim(t2343, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2345: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2348 = ltorch.neg(t2345) # t2348: "cuda:0 bf16[1, 32, 512, 64]" - # t2346 = prims.convert_element_type(t2345, dtypes.float32) # t2346: "cuda:0 f32[1, 32, 512, 64]" - # t2347 = prims.neg(t2346) # t2347: "cuda:0 f32[1, 32, 512, 64]" - # t2348 = prims.convert_element_type(t2347, dtypes.bfloat16) # t2348: "cuda:0 bf16[1, 32, 512, 64]" - t2349 = ltorch.cat((t2348, t2344), -1) # t2349: "cuda:0 bf16[1, 32, 512, 128]" - # t2349 = prims.cat((t2348, t2344), -1) # t2349: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2352 = ltorch.mul(t2343, cos) # t2352: "cuda:0 f32[1, 32, 512, 128]" - # t2350 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2350: "cuda:0 f32[1, 32, 512, 128]" - # t2351 = prims.convert_element_type(t2343, dtypes.float32) # t2351: "cuda:0 f32[1, 32, 512, 128]" - # t2352 = prims.mul(t2351, t2350) # t2352: "cuda:0 f32[1, 32, 512, 128]" - t2355 = ltorch.mul(t2349, sin) # t2355: "cuda:0 f32[1, 32, 512, 128]" - # t2353 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2353: "cuda:0 f32[1, 32, 512, 128]" - # t2354 = prims.convert_element_type(t2349, dtypes.float32) # t2354: "cuda:0 f32[1, 32, 512, 128]" - # t2355 = prims.mul(t2354, t2353) # t2355: "cuda:0 f32[1, 32, 512, 128]" - t2356 = ltorch.add(t2352, t2355, alpha=None) # t2356: "cuda:0 f32[1, 32, 512, 128]" - # t2356 = prims.add(t2352, t2355) # t2356: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2357 = ltorch.to(t2356, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2357: "cuda:0 bf16[1, 32, 512, 128]" - # t2357 = prims.convert_element_type(t2356, dtypes.bfloat16) # t2357: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2358 = ltorch.getitem(t2341, (..., slice(None, 128, None))) # t2358: "cuda:0 bf16[1, 32, 512, 128]" - # t2358 = prims.slice_prim(t2341, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2358: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2359 = ltorch.getitem(t2358, (..., slice(None, 64, None))) # t2359: "cuda:0 bf16[1, 32, 512, 64]" - # t2359 = prims.slice_prim(t2358, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2359: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2360 = ltorch.getitem(t2358, (..., slice(64, None, None))) # t2360: "cuda:0 bf16[1, 32, 512, 64]" - # t2360 = prims.slice_prim(t2358, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2360: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2363 = ltorch.neg(t2360) # t2363: "cuda:0 bf16[1, 32, 512, 64]" - # t2361 = prims.convert_element_type(t2360, dtypes.float32) # t2361: "cuda:0 f32[1, 32, 512, 64]" - # t2362 = prims.neg(t2361) # t2362: "cuda:0 f32[1, 32, 512, 64]" - # t2363 = prims.convert_element_type(t2362, dtypes.bfloat16) # t2363: "cuda:0 bf16[1, 32, 512, 64]" - t2364 = ltorch.cat((t2363, t2359), -1) # t2364: "cuda:0 bf16[1, 32, 512, 128]" - # t2364 = prims.cat((t2363, t2359), -1) # t2364: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2367 = ltorch.mul(t2358, cos) # t2367: "cuda:0 f32[1, 32, 512, 128]" - # t2365 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2365: "cuda:0 f32[1, 32, 512, 128]" - # t2366 = prims.convert_element_type(t2358, dtypes.float32) # t2366: "cuda:0 f32[1, 32, 512, 128]" - # t2367 = prims.mul(t2366, t2365) # t2367: "cuda:0 f32[1, 32, 512, 128]" - t2370 = ltorch.mul(t2364, sin) # t2370: "cuda:0 f32[1, 32, 512, 128]" - # t2368 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2368: "cuda:0 f32[1, 32, 512, 128]" - # t2369 = prims.convert_element_type(t2364, dtypes.float32) # t2369: "cuda:0 f32[1, 32, 512, 128]" - # t2370 = prims.mul(t2369, t2368) # t2370: "cuda:0 f32[1, 32, 512, 128]" - t2371 = ltorch.add(t2367, t2370, alpha=None) # t2371: "cuda:0 f32[1, 32, 512, 128]" - # t2371 = prims.add(t2367, t2370) # t2371: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2372 = ltorch.to(t2371, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2372: "cuda:0 bf16[1, 32, 512, 128]" - # t2372 = prims.convert_element_type(t2371, dtypes.bfloat16) # t2372: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2373 = ltorch.getitem(t2340, (..., slice(128, None, None))) # t2373: "cuda:0 bf16[1, 32, 512, 0]" - # t2373 = prims.slice_prim(t2340, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2373: "cuda:0 bf16[1, 32, 512, 0]" - t2374 = ltorch.cat((t2357, t2373), -1) # t2374: "cuda:0 bf16[1, 32, 512, 128]" - # t2374 = prims.cat((t2357, t2373), -1) # t2374: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2375 = ltorch.getitem(t2341, (..., slice(128, None, None))) # t2375: "cuda:0 bf16[1, 32, 512, 0]" - # t2375 = prims.slice_prim(t2341, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2375: "cuda:0 bf16[1, 32, 512, 0]" - t2376 = ltorch.cat((t2372, t2375), -1) # t2376: "cuda:0 bf16[1, 32, 512, 128]" - # t2376 = prims.cat((t2372, t2375), -1) # t2376: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2406 = ltorch.scaled_dot_product_attention(t2374, t2376, t2342, None, 0.0, True, scale=0.08838834764831843) # t2406: "cuda:0 bf16[1, 32, 512, 128]" - # t2379 = ltorch.mul(t2374, 0.29730177875068026) # t2379: "cuda:0 bf16[1, 32, 512, 128]" - # t2377 = prims.convert_element_type(t2374, dtypes.float32) # t2377: "cuda:0 f32[1, 32, 512, 128]" - # t2378 = prims.mul(t2377, 0.29730177875068026) # t2378: "cuda:0 f32[1, 32, 512, 128]" - # t2379 = prims.convert_element_type(t2378, dtypes.bfloat16) # t2379: "cuda:0 bf16[1, 32, 512, 128]" - # t2380 = ltorch.transpose(t2376, -2, -1) # t2380: "cuda:0 bf16[1, 32, 128, 512]" - # t2380 = prims.transpose(t2376, (0, 1, 3, 2)) # t2380: "cuda:0 bf16[1, 32, 128, 512]" - # t2383 = ltorch.mul(t2380, 0.29730177875068026) # t2383: "cuda:0 bf16[1, 32, 128, 512]" - # t2381 = prims.convert_element_type(t2380, dtypes.float32) # t2381: "cuda:0 f32[1, 32, 128, 512]" - # t2382 = prims.mul(t2381, 0.29730177875068026) # t2382: "cuda:0 f32[1, 32, 128, 512]" - # t2383 = prims.convert_element_type(t2382, dtypes.bfloat16) # t2383: "cuda:0 bf16[1, 32, 128, 512]" - # t2384 = ltorch.matmul(t2379, t2383) # t2384: "cuda:0 bf16[1, 32, 512, 512]" - # t2384 = prims.matmul(t2379, t2383) # t2384: "cuda:0 bf16[1, 32, 512, 512]" - # t2394 = ltorch.tril(t2384, 0, fill_value=-float('inf')) # t2394: "cuda:0 bf16[1, 32, 512, 512]" - # t2385 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2385: "cuda:0 i64[512]" - # t2385 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2385: "cuda:0 i64[512]" - # t2386 = ltorch.unsqueeze(t2385, -1) # t2386: "cuda:0 i64[512, 1]" - # t2386 = prims.broadcast_in_dim(t2385, [512, 1], [0]) # t2386: "cuda:0 i64[512, 1]" - # t2387 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2387: "cuda:0 i64[512]" - # t2387 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2387: "cuda:0 i64[512]" - # t2388 = ltorch.unsqueeze(t2387, -2) # t2388: "cuda:0 i64[1, 512]" - # t2388 = prims.broadcast_in_dim(t2387, [1, 512], [1]) # t2388: "cuda:0 i64[1, 512]" - # t2389 = ltorch.add(t2386, 0, alpha=None) # t2389: "cuda:0 i64[512, 1]" - # t2389 = prims.add(t2386, 0) # t2389: "cuda:0 i64[512, 1]" - # t2392 = ltorch.ge(t2389, t2388) # t2392: "cuda:0 b8[512, 512]" - # t2390 = prims.broadcast_in_dim(t2389, (512, 512), (0, 1)) # t2390: "cuda:0 i64[512, 512]" - # t2391 = prims.broadcast_in_dim(t2388, (512, 512), (0, 1)) # t2391: "cuda:0 i64[512, 512]" - # t2392 = prims.ge(t2390, t2391) # t2392: "cuda:0 b8[512, 512]" - # t2394 = ltorch.where(t2392, t2384, -float('inf')) # t2394: "cuda:0 bf16[1, 32, 512, 512]" - # t2393 = prims.broadcast_in_dim(t2392, (1, 32, 512, 512), (2, 3)) # t2393: "cuda:0 b8[1, 32, 512, 512]" - # t2394 = prims.where(t2393, t2384, -float('inf')) # t2394: "cuda:0 bf16[1, 32, 512, 512]" - # t2405 = ltorch._softmax(t2394, -1, dtype=None) # t2405: "cuda:0 bf16[1, 32, 512, 512]" - # t2395 = ltorch.to(t2394, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2395: "cuda:0 f32[1, 32, 512, 512]" - # t2395 = prims.convert_element_type(t2394, dtypes.float32) # t2395: "cuda:0 f32[1, 32, 512, 512]" - # t2397 = ltorch.amax(t2395, -1, True) # t2397: "cuda:0 f32[1, 32, 512, 1]" - # t2396 = prims.amax(t2395, (3,)) # t2396: "cuda:0 f32[1, 32, 512]" - # t2397 = prims.broadcast_in_dim(t2396, [1, 32, 512, 1], [0, 1, 2]) # t2397: "cuda:0 f32[1, 32, 512, 1]" - # t2399 = ltorch.sub(t2395, t2397, alpha=None) # t2399: "cuda:0 f32[1, 32, 512, 512]" - # t2398 = prims.broadcast_in_dim(t2397, (1, 32, 512, 512), (0, 1, 2, 3)) # t2398: "cuda:0 f32[1, 32, 512, 512]" - # t2399 = prims.sub(t2395, t2398) # t2399: "cuda:0 f32[1, 32, 512, 512]" - # t2400 = ltorch.exp(t2399) # t2400: "cuda:0 f32[1, 32, 512, 512]" - # t2400 = prims.exp(t2399) # t2400: "cuda:0 f32[1, 32, 512, 512]" - # t2402 = ltorch.sum(t2400, -1, True, dtype=None) # t2402: "cuda:0 f32[1, 32, 512, 1]" - # t2401 = prims.sum(t2400, (3,)) # t2401: "cuda:0 f32[1, 32, 512]" - # t2402 = prims.broadcast_in_dim(t2401, [1, 32, 512, 1], [0, 1, 2]) # t2402: "cuda:0 f32[1, 32, 512, 1]" - # t2404 = ltorch.true_divide(t2400, t2402) # t2404: "cuda:0 f32[1, 32, 512, 512]" - # t2403 = prims.broadcast_in_dim(t2402, (1, 32, 512, 512), (0, 1, 2, 3)) # t2403: "cuda:0 f32[1, 32, 512, 512]" - # t2404 = prims.div(t2400, t2403) # t2404: "cuda:0 f32[1, 32, 512, 512]" - # t2405 = ltorch.to(t2404, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2405: "cuda:0 bf16[1, 32, 512, 512]" - # t2405 = prims.convert_element_type(t2404, dtypes.bfloat16) # t2405: "cuda:0 bf16[1, 32, 512, 512]" - # t2406 = ltorch.matmul(t2405, t2342) # t2406: "cuda:0 bf16[1, 32, 512, 128]" - # t2406 = prims.matmul(t2405, t2342) # t2406: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2407 = ltorch.transpose(t2406, 1, 2) # t2407: "cuda:0 bf16[1, 512, 32, 128]" - # t2407 = prims.transpose(t2406, (0, 2, 1, 3)) # t2407: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2408 = ltorch.reshape(t2407, 1, 512, 4096) # t2408: "cuda:0 bf16[1, 512, 4096]" - # t2408 = prims.reshape(t2407, (1, 512, 4096)) # t2408: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2412 = ltorch.linear(t2408, t_transformer_h_14_attn_proj_weight, None) # t2412: "cuda:0 bf16[1, 512, 4096]" - # t2412 = prims.linear(t2408, t_transformer_h_14_attn_proj_weight, None) # t2412: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2416 = ltorch.add(t2412, t2306, alpha=None) # t2416: "cuda:0 bf16[1, 512, 4096]" - # t2413 = prims.convert_element_type(t2412, dtypes.float32) # t2413: "cuda:0 f32[1, 512, 4096]" - # t2414 = prims.convert_element_type(t2306, dtypes.float32) # t2414: "cuda:0 f32[1, 512, 4096]" - # t2415 = prims.add(t2413, t2414) # t2415: "cuda:0 f32[1, 512, 4096]" - # t2416 = prims.convert_element_type(t2415, dtypes.bfloat16) # t2416: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2417 = prims.convert_element_type(t2416, dtypes.float32) # t2417: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2418 = ltorch.mul(t2417, t2417) # t2418: "cuda:0 f32[1, 512, 4096]" - # t2418 = prims.mul(t2417, t2417) # t2418: "cuda:0 f32[1, 512, 4096]" - t2422 = ltorch.mean(t2418, -1, True, dtype=None) # t2422: "cuda:0 f32[1, 512, 1]" - # t2420 = prims.sum(t2418, (2,)) # t2420: "cuda:0 f32[1, 512]" - # t2421 = prims.broadcast_in_dim(t2420, [1, 512, 1], [0, 1]) # t2421: "cuda:0 f32[1, 512, 1]" - # t2422 = ltorch.true_divide(t2421, 4096) # t2422: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2422 = prims.div(t2421, 4096.0) # t2422: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2424 = ltorch.add(t2422, 1e-05, alpha=None) # t2424: "cuda:0 f32[1, 512, 1]" - # t2424 = prims.add(t2422, 1e-05) # t2424: "cuda:0 f32[1, 512, 1]" - t2425 = ltorch.rsqrt(t2424) # t2425: "cuda:0 f32[1, 512, 1]" - # t2425 = prims.rsqrt(t2424) # t2425: "cuda:0 f32[1, 512, 1]" - t2427 = ltorch.mul(t2417, t2425) # t2427: "cuda:0 f32[1, 512, 4096]" - # t2426 = prims.broadcast_in_dim(t2425, (1, 512, 4096), (0, 1, 2)) # t2426: "cuda:0 f32[1, 512, 4096]" - # t2427 = prims.mul(t2417, t2426) # t2427: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2428 = ltorch.to(t2427, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2428: "cuda:0 bf16[1, 512, 4096]" - # t2428 = prims.convert_element_type(t2427, dtypes.bfloat16) # t2428: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2438 = ltorch.mul(t2428, t_transformer_h_14_norm_2_weight) # t2438: "cuda:0 bf16[1, 512, 4096]" - # t2434 = prims.broadcast_in_dim(t_transformer_h_14_norm_2_weight, (1, 512, 4096), (2,)) # t2434: "cuda:0 bf16[1, 512, 4096]" - # t2435 = prims.convert_element_type(t2428, dtypes.float32) # t2435: "cuda:0 f32[1, 512, 4096]" - # t2436 = prims.convert_element_type(t2434, dtypes.float32) # t2436: "cuda:0 f32[1, 512, 4096]" - # t2437 = prims.mul(t2435, t2436) # t2437: "cuda:0 f32[1, 512, 4096]" - # t2438 = prims.convert_element_type(t2437, dtypes.bfloat16) # t2438: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2443 = ltorch.linear(t2438, t_transformer_h_14_mlp_fc_1_weight, None) # t2443: "cuda:0 bf16[1, 512, 11008]" - # t2443 = prims.linear(t2438, t_transformer_h_14_mlp_fc_1_weight, None) # t2443: "cuda:0 bf16[1, 512, 11008]" - t2447 = ltorch.linear(t2438, t_transformer_h_14_mlp_fc_2_weight, None) # t2447: "cuda:0 bf16[1, 512, 11008]" - # t2447 = prims.linear(t2438, t_transformer_h_14_mlp_fc_2_weight, None) # t2447: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2457 = ltorch.silu(t2443, False) # t2457: "cuda:0 bf16[1, 512, 11008]" - # t2448 = prims.convert_element_type(t2443, dtypes.float32) # t2448: "cuda:0 f32[1, 512, 11008]" - # t2449 = prims.neg(t2448) # t2449: "cuda:0 f32[1, 512, 11008]" - # t2450 = prims.exp(t2449) # t2450: "cuda:0 f32[1, 512, 11008]" - # t2451 = prims.add(1.0, t2450) # t2451: "cuda:0 f32[1, 512, 11008]" - # t2452 = prims.reciprocal(t2451) # t2452: "cuda:0 f32[1, 512, 11008]" - # t2453 = prims.convert_element_type(t2452, dtypes.bfloat16) # t2453: "cuda:0 bf16[1, 512, 11008]" - # t2454 = prims.convert_element_type(t2443, dtypes.float32) # t2454: "cuda:0 f32[1, 512, 11008]" - # t2455 = prims.convert_element_type(t2453, dtypes.float32) # t2455: "cuda:0 f32[1, 512, 11008]" - # t2456 = prims.mul(t2454, t2455) # t2456: "cuda:0 f32[1, 512, 11008]" - # t2457 = prims.convert_element_type(t2456, dtypes.bfloat16) # t2457: "cuda:0 bf16[1, 512, 11008]" - t2461 = ltorch.mul(t2457, t2447) # t2461: "cuda:0 bf16[1, 512, 11008]" - # t2458 = prims.convert_element_type(t2457, dtypes.float32) # t2458: "cuda:0 f32[1, 512, 11008]" - # t2459 = prims.convert_element_type(t2447, dtypes.float32) # t2459: "cuda:0 f32[1, 512, 11008]" - # t2460 = prims.mul(t2458, t2459) # t2460: "cuda:0 f32[1, 512, 11008]" - # t2461 = prims.convert_element_type(t2460, dtypes.bfloat16) # t2461: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2465 = ltorch.linear(t2461, t_transformer_h_14_mlp_proj_weight, None) # t2465: "cuda:0 bf16[1, 512, 4096]" - # t2465 = prims.linear(t2461, t_transformer_h_14_mlp_proj_weight, None) # t2465: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2469 = ltorch.add(t2465, t2416, alpha=None) # t2469: "cuda:0 bf16[1, 512, 4096]" - # t2466 = prims.convert_element_type(t2465, dtypes.float32) # t2466: "cuda:0 f32[1, 512, 4096]" - # t2467 = prims.convert_element_type(t2416, dtypes.float32) # t2467: "cuda:0 f32[1, 512, 4096]" - # t2468 = prims.add(t2466, t2467) # t2468: "cuda:0 f32[1, 512, 4096]" - # t2469 = prims.convert_element_type(t2468, dtypes.bfloat16) # t2469: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2471 = prims.convert_element_type(t2469, dtypes.float32) # t2471: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2472 = ltorch.mul(t2471, t2471) # t2472: "cuda:0 f32[1, 512, 4096]" - # t2472 = prims.mul(t2471, t2471) # t2472: "cuda:0 f32[1, 512, 4096]" - t2476 = ltorch.mean(t2472, -1, True, dtype=None) # t2476: "cuda:0 f32[1, 512, 1]" - # t2474 = prims.sum(t2472, (2,)) # t2474: "cuda:0 f32[1, 512]" - # t2475 = prims.broadcast_in_dim(t2474, [1, 512, 1], [0, 1]) # t2475: "cuda:0 f32[1, 512, 1]" - # t2476 = ltorch.true_divide(t2475, 4096) # t2476: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2476 = prims.div(t2475, 4096.0) # t2476: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2478 = ltorch.add(t2476, 1e-05, alpha=None) # t2478: "cuda:0 f32[1, 512, 1]" - # t2478 = prims.add(t2476, 1e-05) # t2478: "cuda:0 f32[1, 512, 1]" - t2479 = ltorch.rsqrt(t2478) # t2479: "cuda:0 f32[1, 512, 1]" - # t2479 = prims.rsqrt(t2478) # t2479: "cuda:0 f32[1, 512, 1]" - t2481 = ltorch.mul(t2471, t2479) # t2481: "cuda:0 f32[1, 512, 4096]" - # t2480 = prims.broadcast_in_dim(t2479, (1, 512, 4096), (0, 1, 2)) # t2480: "cuda:0 f32[1, 512, 4096]" - # t2481 = prims.mul(t2471, t2480) # t2481: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2482 = ltorch.to(t2481, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2482: "cuda:0 bf16[1, 512, 4096]" - # t2482 = prims.convert_element_type(t2481, dtypes.bfloat16) # t2482: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2492 = ltorch.mul(t2482, t_transformer_h_15_norm_1_weight) # t2492: "cuda:0 bf16[1, 512, 4096]" - # t2488 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_weight, (1, 512, 4096), (2,)) # t2488: "cuda:0 bf16[1, 512, 4096]" - # t2489 = prims.convert_element_type(t2482, dtypes.float32) # t2489: "cuda:0 f32[1, 512, 4096]" - # t2490 = prims.convert_element_type(t2488, dtypes.float32) # t2490: "cuda:0 f32[1, 512, 4096]" - # t2491 = prims.mul(t2489, t2490) # t2491: "cuda:0 f32[1, 512, 4096]" - # t2492 = prims.convert_element_type(t2491, dtypes.bfloat16) # t2492: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2497 = ltorch.linear(t2492, t_transformer_h_15_attn_attn_weight, None) # t2497: "cuda:0 bf16[1, 512, 12288]" - # t2497 = prims.linear(t2492, t_transformer_h_15_attn_attn_weight, None) # t2497: "cuda:0 bf16[1, 512, 12288]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:220: qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size) - t2498 = ltorch.view(t2497, 1, 512, 32, 3, 128) # t2498: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2498 = ltorch.reshape(t2497, (1, 512, 32, 3, 128)) # t2498: "cuda:0 bf16[1, 512, 32, 3, 128]" - # t2498 = prims.reshape(t2497, (1, 512, 32, 3, 128)) # t2498: "cuda:0 bf16[1, 512, 32, 3, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:221: qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs) - t2499 = ltorch.permute(t2498, 0, 2, 3, 1, 4) # t2499: "cuda:0 bf16[1, 32, 3, 512, 128]" - # t2499 = prims.transpose(t2498, (0, 2, 3, 1, 4)) # t2499: "cuda:0 bf16[1, 32, 3, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:224: q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) - (t2500, t2501, t2502) = ltorch.split(t2499, (1, 1, 1), 2) - # t2500 = prims.slice_prim(t2499, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2500: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2501 = prims.slice_prim(t2499, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2501: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2502 = prims.slice_prim(t2499, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2502: "cuda:0 bf16[1, 32, 1, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:233: q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) - t2503 = ltorch.reshape(t2500, 1, -1, 512, 128) # t2503: "cuda:0 bf16[1, 32, 512, 128]" - # t2503 = prims.reshape(t2500, (1, 32, 512, 128)) # t2503: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:234: k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) - t2504 = ltorch.reshape(t2501, 1, -1, 512, 128) # t2504: "cuda:0 bf16[1, 32, 512, 128]" - # t2504 = prims.reshape(t2501, (1, 32, 512, 128)) # t2504: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:235: v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) - t2505 = ltorch.reshape(t2502, 1, -1, 512, 128) # t2505: "cuda:0 bf16[1, 32, 512, 128]" - # t2505 = prims.reshape(t2502, (1, 32, 512, 128)) # t2505: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:237: q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - t2506 = ltorch.getitem(t2503, (..., slice(None, 128, None))) # t2506: "cuda:0 bf16[1, 32, 512, 128]" - # t2506 = prims.slice_prim(t2503, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2506: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2507 = ltorch.getitem(t2506, (..., slice(None, 64, None))) # t2507: "cuda:0 bf16[1, 32, 512, 64]" - # t2507 = prims.slice_prim(t2506, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2507: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2508 = ltorch.getitem(t2506, (..., slice(64, None, None))) # t2508: "cuda:0 bf16[1, 32, 512, 64]" - # t2508 = prims.slice_prim(t2506, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2508: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2511 = ltorch.neg(t2508) # t2511: "cuda:0 bf16[1, 32, 512, 64]" - # t2509 = prims.convert_element_type(t2508, dtypes.float32) # t2509: "cuda:0 f32[1, 32, 512, 64]" - # t2510 = prims.neg(t2509) # t2510: "cuda:0 f32[1, 32, 512, 64]" - # t2511 = prims.convert_element_type(t2510, dtypes.bfloat16) # t2511: "cuda:0 bf16[1, 32, 512, 64]" - t2512 = ltorch.cat((t2511, t2507), -1) # t2512: "cuda:0 bf16[1, 32, 512, 128]" - # t2512 = prims.cat((t2511, t2507), -1) # t2512: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2515 = ltorch.mul(t2506, cos) # t2515: "cuda:0 f32[1, 32, 512, 128]" - # t2513 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2513: "cuda:0 f32[1, 32, 512, 128]" - # t2514 = prims.convert_element_type(t2506, dtypes.float32) # t2514: "cuda:0 f32[1, 32, 512, 128]" - # t2515 = prims.mul(t2514, t2513) # t2515: "cuda:0 f32[1, 32, 512, 128]" - t2518 = ltorch.mul(t2512, sin) # t2518: "cuda:0 f32[1, 32, 512, 128]" - # t2516 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2516: "cuda:0 f32[1, 32, 512, 128]" - # t2517 = prims.convert_element_type(t2512, dtypes.float32) # t2517: "cuda:0 f32[1, 32, 512, 128]" - # t2518 = prims.mul(t2517, t2516) # t2518: "cuda:0 f32[1, 32, 512, 128]" - t2519 = ltorch.add(t2515, t2518, alpha=None) # t2519: "cuda:0 f32[1, 32, 512, 128]" - # t2519 = prims.add(t2515, t2518) # t2519: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2520 = ltorch.to(t2519, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2520: "cuda:0 bf16[1, 32, 512, 128]" - # t2520 = prims.convert_element_type(t2519, dtypes.bfloat16) # t2520: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:238: k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - t2521 = ltorch.getitem(t2504, (..., slice(None, 128, None))) # t2521: "cuda:0 bf16[1, 32, 512, 128]" - # t2521 = prims.slice_prim(t2504, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2521: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:375: x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - t2522 = ltorch.getitem(t2521, (..., slice(None, 64, None))) # t2522: "cuda:0 bf16[1, 32, 512, 64]" - # t2522 = prims.slice_prim(t2521, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2522: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:376: x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - t2523 = ltorch.getitem(t2521, (..., slice(64, None, None))) # t2523: "cuda:0 bf16[1, 32, 512, 64]" - # t2523 = prims.slice_prim(t2521, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2523: "cuda:0 bf16[1, 32, 512, 64]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:377: rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - t2526 = ltorch.neg(t2523) # t2526: "cuda:0 bf16[1, 32, 512, 64]" - # t2524 = prims.convert_element_type(t2523, dtypes.float32) # t2524: "cuda:0 f32[1, 32, 512, 64]" - # t2525 = prims.neg(t2524) # t2525: "cuda:0 f32[1, 32, 512, 64]" - # t2526 = prims.convert_element_type(t2525, dtypes.bfloat16) # t2526: "cuda:0 bf16[1, 32, 512, 64]" - t2527 = ltorch.cat((t2526, t2522), -1) # t2527: "cuda:0 bf16[1, 32, 512, 128]" - # t2527 = prims.cat((t2526, t2522), -1) # t2527: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:378: roped = (x * cos) + (rotated * sin) - t2530 = ltorch.mul(t2521, cos) # t2530: "cuda:0 f32[1, 32, 512, 128]" - # t2528 = prims.broadcast_in_dim(cos, (1, 32, 512, 128), (2, 3)) # t2528: "cuda:0 f32[1, 32, 512, 128]" - # t2529 = prims.convert_element_type(t2521, dtypes.float32) # t2529: "cuda:0 f32[1, 32, 512, 128]" - # t2530 = prims.mul(t2529, t2528) # t2530: "cuda:0 f32[1, 32, 512, 128]" - t2533 = ltorch.mul(t2527, sin) # t2533: "cuda:0 f32[1, 32, 512, 128]" - # t2531 = prims.broadcast_in_dim(sin, (1, 32, 512, 128), (2, 3)) # t2531: "cuda:0 f32[1, 32, 512, 128]" - # t2532 = prims.convert_element_type(t2527, dtypes.float32) # t2532: "cuda:0 f32[1, 32, 512, 128]" - # t2533 = prims.mul(t2532, t2531) # t2533: "cuda:0 f32[1, 32, 512, 128]" - t2534 = ltorch.add(t2530, t2533, alpha=None) # t2534: "cuda:0 f32[1, 32, 512, 128]" - # t2534 = prims.add(t2530, t2533) # t2534: "cuda:0 f32[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:379: return roped.to(dtype=x.dtype) - t2535 = ltorch.to(t2534, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2535: "cuda:0 bf16[1, 32, 512, 128]" - # t2535 = prims.convert_element_type(t2534, dtypes.bfloat16) # t2535: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:239: q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) - t2536 = ltorch.getitem(t2503, (..., slice(128, None, None))) # t2536: "cuda:0 bf16[1, 32, 512, 0]" - # t2536 = prims.slice_prim(t2503, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2536: "cuda:0 bf16[1, 32, 512, 0]" - t2537 = ltorch.cat((t2520, t2536), -1) # t2537: "cuda:0 bf16[1, 32, 512, 128]" - # t2537 = prims.cat((t2520, t2536), -1) # t2537: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:240: k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) - t2538 = ltorch.getitem(t2504, (..., slice(128, None, None))) # t2538: "cuda:0 bf16[1, 32, 512, 0]" - # t2538 = prims.slice_prim(t2504, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2538: "cuda:0 bf16[1, 32, 512, 0]" - t2539 = ltorch.cat((t2535, t2538), -1) # t2539: "cuda:0 bf16[1, 32, 512, 128]" - # t2539 = prims.cat((t2535, t2538), -1) # t2539: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:258: y = torch.nn.functional.scaled_dot_product_attention( - t2569 = ltorch.scaled_dot_product_attention(t2537, t2539, t2505, None, 0.0, True, scale=0.08838834764831843) # t2569: "cuda:0 bf16[1, 32, 512, 128]" - # t2542 = ltorch.mul(t2537, 0.29730177875068026) # t2542: "cuda:0 bf16[1, 32, 512, 128]" - # t2540 = prims.convert_element_type(t2537, dtypes.float32) # t2540: "cuda:0 f32[1, 32, 512, 128]" - # t2541 = prims.mul(t2540, 0.29730177875068026) # t2541: "cuda:0 f32[1, 32, 512, 128]" - # t2542 = prims.convert_element_type(t2541, dtypes.bfloat16) # t2542: "cuda:0 bf16[1, 32, 512, 128]" - # t2543 = ltorch.transpose(t2539, -2, -1) # t2543: "cuda:0 bf16[1, 32, 128, 512]" - # t2543 = prims.transpose(t2539, (0, 1, 3, 2)) # t2543: "cuda:0 bf16[1, 32, 128, 512]" - # t2546 = ltorch.mul(t2543, 0.29730177875068026) # t2546: "cuda:0 bf16[1, 32, 128, 512]" - # t2544 = prims.convert_element_type(t2543, dtypes.float32) # t2544: "cuda:0 f32[1, 32, 128, 512]" - # t2545 = prims.mul(t2544, 0.29730177875068026) # t2545: "cuda:0 f32[1, 32, 128, 512]" - # t2546 = prims.convert_element_type(t2545, dtypes.bfloat16) # t2546: "cuda:0 bf16[1, 32, 128, 512]" - # t2547 = ltorch.matmul(t2542, t2546) # t2547: "cuda:0 bf16[1, 32, 512, 512]" - # t2547 = prims.matmul(t2542, t2546) # t2547: "cuda:0 bf16[1, 32, 512, 512]" - # t2557 = ltorch.tril(t2547, 0, fill_value=-float('inf')) # t2557: "cuda:0 bf16[1, 32, 512, 512]" - # t2548 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2548: "cuda:0 i64[512]" - # t2548 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2548: "cuda:0 i64[512]" - # t2549 = ltorch.unsqueeze(t2548, -1) # t2549: "cuda:0 i64[512, 1]" - # t2549 = prims.broadcast_in_dim(t2548, [512, 1], [0]) # t2549: "cuda:0 i64[512, 1]" - # t2550 = ltorch.arange(512, None, 1, device=devices.Device("cuda:0"), dtype=None) # t2550: "cuda:0 i64[512]" - # t2550 = prims.iota(512, start=0, step=1, device=devices.Device("cuda:0"), dtype=dtypes.int64) # t2550: "cuda:0 i64[512]" - # t2551 = ltorch.unsqueeze(t2550, -2) # t2551: "cuda:0 i64[1, 512]" - # t2551 = prims.broadcast_in_dim(t2550, [1, 512], [1]) # t2551: "cuda:0 i64[1, 512]" - # t2552 = ltorch.add(t2549, 0, alpha=None) # t2552: "cuda:0 i64[512, 1]" - # t2552 = prims.add(t2549, 0) # t2552: "cuda:0 i64[512, 1]" - # t2555 = ltorch.ge(t2552, t2551) # t2555: "cuda:0 b8[512, 512]" - # t2553 = prims.broadcast_in_dim(t2552, (512, 512), (0, 1)) # t2553: "cuda:0 i64[512, 512]" - # t2554 = prims.broadcast_in_dim(t2551, (512, 512), (0, 1)) # t2554: "cuda:0 i64[512, 512]" - # t2555 = prims.ge(t2553, t2554) # t2555: "cuda:0 b8[512, 512]" - # t2557 = ltorch.where(t2555, t2547, -float('inf')) # t2557: "cuda:0 bf16[1, 32, 512, 512]" - # t2556 = prims.broadcast_in_dim(t2555, (1, 32, 512, 512), (2, 3)) # t2556: "cuda:0 b8[1, 32, 512, 512]" - # t2557 = prims.where(t2556, t2547, -float('inf')) # t2557: "cuda:0 bf16[1, 32, 512, 512]" - # t2568 = ltorch._softmax(t2557, -1, dtype=None) # t2568: "cuda:0 bf16[1, 32, 512, 512]" - # t2558 = ltorch.to(t2557, dtypes.float32, None, device=None, dtype=None, copy=False, memory_format=None) # t2558: "cuda:0 f32[1, 32, 512, 512]" - # t2558 = prims.convert_element_type(t2557, dtypes.float32) # t2558: "cuda:0 f32[1, 32, 512, 512]" - # t2560 = ltorch.amax(t2558, -1, True) # t2560: "cuda:0 f32[1, 32, 512, 1]" - # t2559 = prims.amax(t2558, (3,)) # t2559: "cuda:0 f32[1, 32, 512]" - # t2560 = prims.broadcast_in_dim(t2559, [1, 32, 512, 1], [0, 1, 2]) # t2560: "cuda:0 f32[1, 32, 512, 1]" - # t2562 = ltorch.sub(t2558, t2560, alpha=None) # t2562: "cuda:0 f32[1, 32, 512, 512]" - # t2561 = prims.broadcast_in_dim(t2560, (1, 32, 512, 512), (0, 1, 2, 3)) # t2561: "cuda:0 f32[1, 32, 512, 512]" - # t2562 = prims.sub(t2558, t2561) # t2562: "cuda:0 f32[1, 32, 512, 512]" - # t2563 = ltorch.exp(t2562) # t2563: "cuda:0 f32[1, 32, 512, 512]" - # t2563 = prims.exp(t2562) # t2563: "cuda:0 f32[1, 32, 512, 512]" - # t2565 = ltorch.sum(t2563, -1, True, dtype=None) # t2565: "cuda:0 f32[1, 32, 512, 1]" - # t2564 = prims.sum(t2563, (3,)) # t2564: "cuda:0 f32[1, 32, 512]" - # t2565 = prims.broadcast_in_dim(t2564, [1, 32, 512, 1], [0, 1, 2]) # t2565: "cuda:0 f32[1, 32, 512, 1]" - # t2567 = ltorch.true_divide(t2563, t2565) # t2567: "cuda:0 f32[1, 32, 512, 512]" - # t2566 = prims.broadcast_in_dim(t2565, (1, 32, 512, 512), (0, 1, 2, 3)) # t2566: "cuda:0 f32[1, 32, 512, 512]" - # t2567 = prims.div(t2563, t2566) # t2567: "cuda:0 f32[1, 32, 512, 512]" - # t2568 = ltorch.to(t2567, dtypes.bfloat16, None, device=None, dtype=None, copy=False, memory_format=None) # t2568: "cuda:0 bf16[1, 32, 512, 512]" - # t2568 = prims.convert_element_type(t2567, dtypes.bfloat16) # t2568: "cuda:0 bf16[1, 32, 512, 512]" - # t2569 = ltorch.matmul(t2568, t2505) # t2569: "cuda:0 bf16[1, 32, 512, 128]" - # t2569 = prims.matmul(t2568, t2505) # t2569: "cuda:0 bf16[1, 32, 512, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:261: return y.transpose(1, 2) - t2570 = ltorch.transpose(t2569, 1, 2) # t2570: "cuda:0 bf16[1, 512, 32, 128]" - # t2570 = prims.transpose(t2569, (0, 2, 1, 3)) # t2570: "cuda:0 bf16[1, 512, 32, 128]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:249: y = y.reshape(B, T, self.config.head_size * self.config.n_head) # re-assemble all head outputs side by side - t2571 = ltorch.reshape(t2570, 1, 512, 4096) # t2571: "cuda:0 bf16[1, 512, 4096]" - # t2571 = prims.reshape(t2570, (1, 512, 4096)) # t2571: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2575 = ltorch.linear(t2571, t_transformer_h_15_attn_proj_weight, None) # t2575: "cuda:0 bf16[1, 512, 4096]" - # t2575 = prims.linear(t2571, t_transformer_h_15_attn_proj_weight, None) # t2575: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:186: x = attention_output + x - t2579 = ltorch.add(t2575, t2469, alpha=None) # t2579: "cuda:0 bf16[1, 512, 4096]" - # t2576 = prims.convert_element_type(t2575, dtypes.float32) # t2576: "cuda:0 f32[1, 512, 4096]" - # t2577 = prims.convert_element_type(t2469, dtypes.float32) # t2577: "cuda:0 f32[1, 512, 4096]" - # t2578 = prims.add(t2576, t2577) # t2578: "cuda:0 f32[1, 512, 4096]" - # t2579 = prims.convert_element_type(t2578, dtypes.bfloat16) # t2579: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2580 = prims.convert_element_type(t2579, dtypes.float32) # t2580: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2581 = ltorch.mul(t2580, t2580) # t2581: "cuda:0 f32[1, 512, 4096]" - # t2581 = prims.mul(t2580, t2580) # t2581: "cuda:0 f32[1, 512, 4096]" - t2585 = ltorch.mean(t2581, -1, True, dtype=None) # t2585: "cuda:0 f32[1, 512, 1]" - # t2583 = prims.sum(t2581, (2,)) # t2583: "cuda:0 f32[1, 512]" - # t2584 = prims.broadcast_in_dim(t2583, [1, 512, 1], [0, 1]) # t2584: "cuda:0 f32[1, 512, 1]" - # t2585 = ltorch.true_divide(t2584, 4096) # t2585: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2585 = prims.div(t2584, 4096.0) # t2585: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2587 = ltorch.add(t2585, 1e-05, alpha=None) # t2587: "cuda:0 f32[1, 512, 1]" - # t2587 = prims.add(t2585, 1e-05) # t2587: "cuda:0 f32[1, 512, 1]" - t2588 = ltorch.rsqrt(t2587) # t2588: "cuda:0 f32[1, 512, 1]" - # t2588 = prims.rsqrt(t2587) # t2588: "cuda:0 f32[1, 512, 1]" - t2590 = ltorch.mul(t2580, t2588) # t2590: "cuda:0 f32[1, 512, 4096]" - # t2589 = prims.broadcast_in_dim(t2588, (1, 512, 4096), (0, 1, 2)) # t2589: "cuda:0 f32[1, 512, 4096]" - # t2590 = prims.mul(t2580, t2589) # t2590: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2591 = ltorch.to(t2590, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2591: "cuda:0 bf16[1, 512, 4096]" - # t2591 = prims.convert_element_type(t2590, dtypes.bfloat16) # t2591: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2601 = ltorch.mul(t2591, t_transformer_h_15_norm_2_weight) # t2601: "cuda:0 bf16[1, 512, 4096]" - # t2597 = prims.broadcast_in_dim(t_transformer_h_15_norm_2_weight, (1, 512, 4096), (2,)) # t2597: "cuda:0 bf16[1, 512, 4096]" - # t2598 = prims.convert_element_type(t2591, dtypes.float32) # t2598: "cuda:0 f32[1, 512, 4096]" - # t2599 = prims.convert_element_type(t2597, dtypes.float32) # t2599: "cuda:0 f32[1, 512, 4096]" - # t2600 = prims.mul(t2598, t2599) # t2600: "cuda:0 f32[1, 512, 4096]" - # t2601 = prims.convert_element_type(t2600, dtypes.bfloat16) # t2601: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2606 = ltorch.linear(t2601, t_transformer_h_15_mlp_fc_1_weight, None) # t2606: "cuda:0 bf16[1, 512, 11008]" - # t2606 = prims.linear(t2601, t_transformer_h_15_mlp_fc_1_weight, None) # t2606: "cuda:0 bf16[1, 512, 11008]" - t2610 = ltorch.linear(t2601, t_transformer_h_15_mlp_fc_2_weight, None) # t2610: "cuda:0 bf16[1, 512, 11008]" - # t2610 = prims.linear(t2601, t_transformer_h_15_mlp_fc_2_weight, None) # t2610: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:313: x = torch.nn.functional.silu(x_fc_1) * x_fc_2 - t2620 = ltorch.silu(t2606, False) # t2620: "cuda:0 bf16[1, 512, 11008]" - # t2611 = prims.convert_element_type(t2606, dtypes.float32) # t2611: "cuda:0 f32[1, 512, 11008]" - # t2612 = prims.neg(t2611) # t2612: "cuda:0 f32[1, 512, 11008]" - # t2613 = prims.exp(t2612) # t2613: "cuda:0 f32[1, 512, 11008]" - # t2614 = prims.add(1.0, t2613) # t2614: "cuda:0 f32[1, 512, 11008]" - # t2615 = prims.reciprocal(t2614) # t2615: "cuda:0 f32[1, 512, 11008]" - # t2616 = prims.convert_element_type(t2615, dtypes.bfloat16) # t2616: "cuda:0 bf16[1, 512, 11008]" - # t2617 = prims.convert_element_type(t2606, dtypes.float32) # t2617: "cuda:0 f32[1, 512, 11008]" - # t2618 = prims.convert_element_type(t2616, dtypes.float32) # t2618: "cuda:0 f32[1, 512, 11008]" - # t2619 = prims.mul(t2617, t2618) # t2619: "cuda:0 f32[1, 512, 11008]" - # t2620 = prims.convert_element_type(t2619, dtypes.bfloat16) # t2620: "cuda:0 bf16[1, 512, 11008]" - t2624 = ltorch.mul(t2620, t2610) # t2624: "cuda:0 bf16[1, 512, 11008]" - # t2621 = prims.convert_element_type(t2620, dtypes.float32) # t2621: "cuda:0 f32[1, 512, 11008]" - # t2622 = prims.convert_element_type(t2610, dtypes.float32) # t2622: "cuda:0 f32[1, 512, 11008]" - # t2623 = prims.mul(t2621, t2622) # t2623: "cuda:0 f32[1, 512, 11008]" - # t2624 = prims.convert_element_type(t2623, dtypes.bfloat16) # t2624: "cuda:0 bf16[1, 512, 11008]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2628 = ltorch.linear(t2624, t_transformer_h_15_mlp_proj_weight, None) # t2628: "cuda:0 bf16[1, 512, 4096]" - # t2628 = prims.linear(t2624, t_transformer_h_15_mlp_proj_weight, None) # t2628: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:187: x = self.mlp(self.norm_2(x)) + x - t2632 = ltorch.add(t2628, t2579, alpha=None) # t2632: "cuda:0 bf16[1, 512, 4096]" - # t2629 = prims.convert_element_type(t2628, dtypes.float32) # t2629: "cuda:0 f32[1, 512, 4096]" - # t2630 = prims.convert_element_type(t2579, dtypes.float32) # t2630: "cuda:0 f32[1, 512, 4096]" - # t2631 = prims.add(t2629, t2630) # t2631: "cuda:0 f32[1, 512, 4096]" - # t2632 = prims.convert_element_type(t2631, dtypes.bfloat16) # t2632: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:429: x = x.float() - t2633 = prims.convert_element_type(t2632, dtypes.float32) # t2633: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:431: norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) - t2634 = ltorch.mul(t2633, t2633) # t2634: "cuda:0 f32[1, 512, 4096]" - # t2634 = prims.mul(t2633, t2633) # t2634: "cuda:0 f32[1, 512, 4096]" - t2638 = ltorch.mean(t2634, -1, True, dtype=None) # t2638: "cuda:0 f32[1, 512, 1]" - # t2636 = prims.sum(t2634, (2,)) # t2636: "cuda:0 f32[1, 512]" - # t2637 = prims.broadcast_in_dim(t2636, [1, 512, 1], [0, 1]) # t2637: "cuda:0 f32[1, 512, 1]" - # t2638 = ltorch.true_divide(t2637, 4096) # t2638: "cuda:0 f32[1, 512, 1]" - # _ = prims.convert_element_type(4096, float) - # t2638 = prims.div(t2637, 4096.0) # t2638: "cuda:0 f32[1, 512, 1]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:432: x_normed = x * torch.rsqrt(norm_x + self.eps) - t2640 = ltorch.add(t2638, 1e-05, alpha=None) # t2640: "cuda:0 f32[1, 512, 1]" - # t2640 = prims.add(t2638, 1e-05) # t2640: "cuda:0 f32[1, 512, 1]" - t2641 = ltorch.rsqrt(t2640) # t2641: "cuda:0 f32[1, 512, 1]" - # t2641 = prims.rsqrt(t2640) # t2641: "cuda:0 f32[1, 512, 1]" - t2643 = ltorch.mul(t2633, t2641) # t2643: "cuda:0 f32[1, 512, 4096]" - # t2642 = prims.broadcast_in_dim(t2641, (1, 512, 4096), (0, 1, 2)) # t2642: "cuda:0 f32[1, 512, 4096]" - # t2643 = prims.mul(t2633, t2642) # t2643: "cuda:0 f32[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:433: x_normed = x_normed.to(dtype=dtype) - t2644 = ltorch.to(t2643, None, None, device=None, dtype=dtypes.bfloat16, copy=False, memory_format=None) # t2644: "cuda:0 bf16[1, 512, 4096]" - # t2644 = prims.convert_element_type(t2643, dtypes.bfloat16) # t2644: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/litgpt/model.py:438: return x_normed * self.weight - t2654 = ltorch.mul(t2644, t_transformer_ln_f_weight) # t2654: "cuda:0 bf16[1, 512, 4096]" - # t2650 = prims.broadcast_in_dim(t_transformer_ln_f_weight, (1, 512, 4096), (2,)) # t2650: "cuda:0 bf16[1, 512, 4096]" - # t2651 = prims.convert_element_type(t2644, dtypes.float32) # t2651: "cuda:0 f32[1, 512, 4096]" - # t2652 = prims.convert_element_type(t2650, dtypes.float32) # t2652: "cuda:0 f32[1, 512, 4096]" - # t2653 = prims.mul(t2651, t2652) # t2653: "cuda:0 f32[1, 512, 4096]" - # t2654 = prims.convert_element_type(t2653, dtypes.bfloat16) # t2654: "cuda:0 bf16[1, 512, 4096]" - - # /usr/local/lib/python3.10/dist-packages/torch/nn/modules/linear.py:125: return F.linear(input, self.weight, self.bias) - t2658 = ltorch.linear(t2654, t_lm_head_weight, None) # t2658: "cuda:0 bf16[1, 512, 32000]" - # t2658 = prims.linear(t2654, t_lm_head_weight, None) # t2658: "cuda:0 bf16[1, 512, 32000]" - return t2658 -============================================ END: primal_trace sort_data_parallel_syncs -============================================ START: primal_trace forward_and_backward_from_trace -# Constructed by Augmented forward pass -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight): - # idx: "cuda:0 i64[1, 512]" - # tos1: "cuda:0 f32[4096, 128]" - # t_lm_head_weight: "cuda:0 bf16[32000, 4096]" - # t_sin: "cuda:0 f32[4096, 128]" - # t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_0_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_0_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_1_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_1_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_1_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_2_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_2_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_2_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_3_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_3_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_3_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_4_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_4_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_4_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_5_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_5_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_5_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_6_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_6_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_6_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_7_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_7_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_7_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_8_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_8_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_8_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_9_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_9_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_9_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_10_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_10_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_10_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_11_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_11_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_11_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_12_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_12_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_12_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_13_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_13_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_13_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_14_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_14_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_14_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_15_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_15_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_15_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_ln_f_weight: "cuda:0 bf16[4096]" - # t_transformer_wte_weight: "cuda:0 bf16[32000, 4096]" - t0 = prims.slice_prim(tos1, [0, 0], [512, 128], [1, 1]) # t0: "cuda:0 f32[512, 128]" - t1 = prims.slice_prim(t_sin, [0, 0], [512, 128], [1, 1]) # t1: "cuda:0 f32[512, 128]" - t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 512, 4096]" - # t2 = ltorch.reshape(idx, [512]) # t2: "cuda:0 i64[512]" - # t2 = prims.reshape(idx, (512,)) # t2: "cuda:0 i64[512]" - # t3 = prims.take(t_transformer_wte_weight, t2, 0) # t3: "cuda:0 bf16[512, 4096]" - # t4 = ltorch.reshape(t3, [1, 512, 4096]) # t4: "cuda:0 bf16[1, 512, 4096]" - # t4 = prims.reshape(t3, (1, 512, 4096)) # t4: "cuda:0 bf16[1, 512, 4096]" - t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 512, 4096]" - t6 = ltorch.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - # t6 = prims.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - t7 = prims.sum(t6, (2,)) # t7: "cuda:0 f32[1, 512]" - t8 = prims.broadcast_in_dim(t7, [1, 512, 1], [0, 1]) # t8: "cuda:0 f32[1, 512, 1]" - t9 = ltorch.true_divide(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - # t9 = prims.div(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - t10 = ltorch.add(t9, 1e-05, alpha=None) # t10: "cuda:0 f32[1, 512, 1]" - # t10 = prims.add(t9, 1e-05) # t10: "cuda:0 f32[1, 512, 1]" - t11 = prims.rsqrt(t10) # t11: "cuda:0 f32[1, 512, 1]" - t12 = prims.broadcast_in_dim(t11, (1, 512, 4096), (0, 1, 2)) # t12: "cuda:0 f32[1, 512, 4096]" - t13 = ltorch.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - # t13 = prims.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - t14 = prims.convert_element_type(t13, dtypes.bfloat16) # t14: "cuda:0 bf16[1, 512, 4096]" - t15 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, (1, 512, 4096), (2,)) # t15: "cuda:0 bf16[1, 512, 4096]" - t16 = prims.convert_element_type(t14, dtypes.float32) # t16: "cuda:0 f32[1, 512, 4096]" - t17 = prims.convert_element_type(t15, dtypes.float32) # t17: "cuda:0 f32[1, 512, 4096]" - t18 = ltorch.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - # t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - t19 = prims.convert_element_type(t18, dtypes.bfloat16) # t19: "cuda:0 bf16[1, 512, 4096]" - t20 = prims.linear(t19, t_transformer_h_0_attn_attn_weight, None) # t20: "cuda:0 bf16[1, 512, 12288]" - t21 = prims.reshape(t20, (1, 512, 32, 3, 128)) # t21: "cuda:0 bf16[1, 512, 32, 3, 128]" - t22 = prims.transpose(t21, (0, 2, 3, 1, 4)) # t22: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t23, t24, t25) = ltorch.split(t22, (1, 1, 1), 2) - # t23 = prims.slice_prim(t22, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t23: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t24 = prims.slice_prim(t22, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t24: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t25 = prims.slice_prim(t22, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t25: "cuda:0 bf16[1, 32, 1, 512, 128]" - t26 = prims.reshape(t23, (1, 32, 512, 128)) # t26: "cuda:0 bf16[1, 32, 512, 128]" - t32 = prims.reshape(t24, (1, 32, 512, 128)) # t32: "cuda:0 bf16[1, 32, 512, 128]" - t38 = prims.reshape(t25, (1, 32, 512, 128)) # t38: "cuda:0 bf16[1, 32, 512, 128]" - t39 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t39: "cuda:0 bf16[1, 32, 512, 128]" - t40 = prims.slice_prim(t39, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t40: "cuda:0 bf16[1, 32, 512, 64]" - t41 = prims.slice_prim(t39, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t41: "cuda:0 bf16[1, 32, 512, 64]" - t42 = prims.convert_element_type(t41, dtypes.float32) # t42: "cuda:0 f32[1, 32, 512, 64]" - t43 = prims.neg(t42) # t43: "cuda:0 f32[1, 32, 512, 64]" - t44 = prims.convert_element_type(t43, dtypes.bfloat16) # t44: "cuda:0 bf16[1, 32, 512, 64]" - t45 = prims.cat((t44, t40), -1) # t45: "cuda:0 bf16[1, 32, 512, 128]" - t46 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t46: "cuda:0 f32[1, 32, 512, 128]" - t47 = prims.convert_element_type(t39, dtypes.float32) # t47: "cuda:0 f32[1, 32, 512, 128]" - t48 = ltorch.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - # t48 = prims.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - t49 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t49: "cuda:0 f32[1, 32, 512, 128]" - t50 = prims.convert_element_type(t45, dtypes.float32) # t50: "cuda:0 f32[1, 32, 512, 128]" - t51 = ltorch.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - # t51 = prims.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - t52 = ltorch.add(t48, t51, alpha=None) # t52: "cuda:0 f32[1, 32, 512, 128]" - # t52 = prims.add(t48, t51) # t52: "cuda:0 f32[1, 32, 512, 128]" - t53 = prims.convert_element_type(t52, dtypes.bfloat16) # t53: "cuda:0 bf16[1, 32, 512, 128]" - t54 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 32, 512, 128]" - t55 = prims.slice_prim(t54, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t55: "cuda:0 bf16[1, 32, 512, 64]" - t56 = prims.slice_prim(t54, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t56: "cuda:0 bf16[1, 32, 512, 64]" - t57 = prims.convert_element_type(t56, dtypes.float32) # t57: "cuda:0 f32[1, 32, 512, 64]" - t58 = prims.neg(t57) # t58: "cuda:0 f32[1, 32, 512, 64]" - t59 = prims.convert_element_type(t58, dtypes.bfloat16) # t59: "cuda:0 bf16[1, 32, 512, 64]" - t61 = prims.cat((t59, t55), -1) # t61: "cuda:0 bf16[1, 32, 512, 128]" - t62 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t62: "cuda:0 f32[1, 32, 512, 128]" - t63 = prims.convert_element_type(t54, dtypes.float32) # t63: "cuda:0 f32[1, 32, 512, 128]" - t64 = ltorch.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - # t64 = prims.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - t65 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t65: "cuda:0 f32[1, 32, 512, 128]" - t66 = prims.convert_element_type(t61, dtypes.float32) # t66: "cuda:0 f32[1, 32, 512, 128]" - t67 = ltorch.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - # t67 = prims.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - t68 = ltorch.add(t64, t67, alpha=None) # t68: "cuda:0 f32[1, 32, 512, 128]" - # t68 = prims.add(t64, t67) # t68: "cuda:0 f32[1, 32, 512, 128]" - t69 = prims.convert_element_type(t68, dtypes.bfloat16) # t69: "cuda:0 bf16[1, 32, 512, 128]" - t70 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t70: "cuda:0 bf16[1, 32, 512, 0]" - t71 = prims.cat((t53, t70), -1) # t71: "cuda:0 bf16[1, 32, 512, 128]" - t72 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t72: "cuda:0 bf16[1, 32, 512, 0]" - t74 = prims.cat((t69, t72), -1) # t74: "cuda:0 bf16[1, 32, 512, 128]" - (t75, t76, t77, t78) = cudnn_sdpa_fwd(t71, t74, t38, None, 0.0, True, scale=0.08838834764831843) - t79 = prims.transpose(t75, (0, 2, 1, 3)) # t79: "cuda:0 bf16[1, 512, 32, 128]" - t80 = prims.reshape(t79, (1, 512, 4096)) # t80: "cuda:0 bf16[1, 512, 4096]" - t81 = prims.linear(t80, t_transformer_h_0_attn_proj_weight, None) # t81: "cuda:0 bf16[1, 512, 4096]" - t82 = prims.convert_element_type(t81, dtypes.float32) # t82: "cuda:0 f32[1, 512, 4096]" - t83 = prims.convert_element_type(t4, dtypes.float32) # t83: "cuda:0 f32[1, 512, 4096]" - t84 = ltorch.add(t82, t83, alpha=None) # t84: "cuda:0 f32[1, 512, 4096]" - # t84 = prims.add(t82, t83) # t84: "cuda:0 f32[1, 512, 4096]" - t85 = prims.convert_element_type(t84, dtypes.bfloat16) # t85: "cuda:0 bf16[1, 512, 4096]" - t86 = prims.convert_element_type(t85, dtypes.float32) # t86: "cuda:0 f32[1, 512, 4096]" - t87 = ltorch.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - # t87 = prims.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - t89 = prims.sum(t87, (2,)) # t89: "cuda:0 f32[1, 512]" - t90 = prims.broadcast_in_dim(t89, [1, 512, 1], [0, 1]) # t90: "cuda:0 f32[1, 512, 1]" - t92 = ltorch.true_divide(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - # t92 = prims.div(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - t94 = ltorch.add(t92, 1e-05, alpha=None) # t94: "cuda:0 f32[1, 512, 1]" - # t94 = prims.add(t92, 1e-05) # t94: "cuda:0 f32[1, 512, 1]" - t95 = prims.rsqrt(t94) # t95: "cuda:0 f32[1, 512, 1]" - t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: "cuda:0 f32[1, 512, 4096]" - t97 = ltorch.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - # t97 = prims.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - t98 = prims.convert_element_type(t97, dtypes.bfloat16) # t98: "cuda:0 bf16[1, 512, 4096]" - t99 = prims.broadcast_in_dim(t_transformer_h_0_norm_2_weight, (1, 512, 4096), (2,)) # t99: "cuda:0 bf16[1, 512, 4096]" - t100 = prims.convert_element_type(t98, dtypes.float32) # t100: "cuda:0 f32[1, 512, 4096]" - t101 = prims.convert_element_type(t99, dtypes.float32) # t101: "cuda:0 f32[1, 512, 4096]" - t102 = ltorch.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - # t102 = prims.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - t103 = prims.convert_element_type(t102, dtypes.bfloat16) # t103: "cuda:0 bf16[1, 512, 4096]" - t104 = prims.linear(t103, t_transformer_h_0_mlp_fc_1_weight, None) # t104: "cuda:0 bf16[1, 512, 11008]" - t105 = prims.linear(t103, t_transformer_h_0_mlp_fc_2_weight, None) # t105: "cuda:0 bf16[1, 512, 11008]" - t106 = prims.convert_element_type(t104, dtypes.float32) # t106: "cuda:0 f32[1, 512, 11008]" - t107 = prims.neg(t106) # t107: "cuda:0 f32[1, 512, 11008]" - t108 = prims.exp(t107) # t108: "cuda:0 f32[1, 512, 11008]" - t109 = ltorch.add(1.0, t108, alpha=None) # t109: "cuda:0 f32[1, 512, 11008]" - # t109 = prims.add(1.0, t108) # t109: "cuda:0 f32[1, 512, 11008]" - t110 = prims.reciprocal(t109) # t110: "cuda:0 f32[1, 512, 11008]" - t111 = prims.convert_element_type(t110, dtypes.bfloat16) # t111: "cuda:0 bf16[1, 512, 11008]" - t112 = prims.convert_element_type(t104, dtypes.float32) # t112: "cuda:0 f32[1, 512, 11008]" - t113 = prims.convert_element_type(t111, dtypes.float32) # t113: "cuda:0 f32[1, 512, 11008]" - t114 = ltorch.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - # t114 = prims.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - t115 = prims.convert_element_type(t114, dtypes.bfloat16) # t115: "cuda:0 bf16[1, 512, 11008]" - t116 = prims.convert_element_type(t115, dtypes.float32) # t116: "cuda:0 f32[1, 512, 11008]" - t117 = prims.convert_element_type(t105, dtypes.float32) # t117: "cuda:0 f32[1, 512, 11008]" - t118 = ltorch.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - # t118 = prims.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - t119 = prims.convert_element_type(t118, dtypes.bfloat16) # t119: "cuda:0 bf16[1, 512, 11008]" - t120 = prims.linear(t119, t_transformer_h_0_mlp_proj_weight, None) # t120: "cuda:0 bf16[1, 512, 4096]" - t121 = prims.convert_element_type(t120, dtypes.float32) # t121: "cuda:0 f32[1, 512, 4096]" - t122 = prims.convert_element_type(t85, dtypes.float32) # t122: "cuda:0 f32[1, 512, 4096]" - t123 = ltorch.add(t121, t122, alpha=None) # t123: "cuda:0 f32[1, 512, 4096]" - # t123 = prims.add(t121, t122) # t123: "cuda:0 f32[1, 512, 4096]" - t124 = prims.convert_element_type(t123, dtypes.bfloat16) # t124: "cuda:0 bf16[1, 512, 4096]" - t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 512, 4096]" - t126 = ltorch.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - # t126 = prims.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - t128 = prims.sum(t126, (2,)) # t128: "cuda:0 f32[1, 512]" - t129 = prims.broadcast_in_dim(t128, [1, 512, 1], [0, 1]) # t129: "cuda:0 f32[1, 512, 1]" - t131 = ltorch.true_divide(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - # t131 = prims.div(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - t133 = ltorch.add(t131, 1e-05, alpha=None) # t133: "cuda:0 f32[1, 512, 1]" - # t133 = prims.add(t131, 1e-05) # t133: "cuda:0 f32[1, 512, 1]" - t134 = prims.rsqrt(t133) # t134: "cuda:0 f32[1, 512, 1]" - t135 = prims.broadcast_in_dim(t134, (1, 512, 4096), (0, 1, 2)) # t135: "cuda:0 f32[1, 512, 4096]" - t136 = ltorch.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - # t136 = prims.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: "cuda:0 bf16[1, 512, 4096]" - t138 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_weight, (1, 512, 4096), (2,)) # t138: "cuda:0 bf16[1, 512, 4096]" - t139 = prims.convert_element_type(t137, dtypes.float32) # t139: "cuda:0 f32[1, 512, 4096]" - t140 = prims.convert_element_type(t138, dtypes.float32) # t140: "cuda:0 f32[1, 512, 4096]" - t141 = ltorch.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - # t141 = prims.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - t142 = prims.convert_element_type(t141, dtypes.bfloat16) # t142: "cuda:0 bf16[1, 512, 4096]" - t143 = prims.linear(t142, t_transformer_h_1_attn_attn_weight, None) # t143: "cuda:0 bf16[1, 512, 12288]" - t149 = prims.reshape(t143, (1, 512, 32, 3, 128)) # t149: "cuda:0 bf16[1, 512, 32, 3, 128]" - t155 = prims.transpose(t149, (0, 2, 3, 1, 4)) # t155: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t156, t157, t158) = ltorch.split(t155, (1, 1, 1), 2) - # t156 = prims.slice_prim(t155, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t156: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t157 = prims.slice_prim(t155, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t157: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t158 = prims.slice_prim(t155, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t158: "cuda:0 bf16[1, 32, 1, 512, 128]" - t164 = prims.reshape(t156, (1, 32, 512, 128)) # t164: "cuda:0 bf16[1, 32, 512, 128]" - t170 = prims.reshape(t157, (1, 32, 512, 128)) # t170: "cuda:0 bf16[1, 32, 512, 128]" - t176 = prims.reshape(t158, (1, 32, 512, 128)) # t176: "cuda:0 bf16[1, 32, 512, 128]" - t177 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t177: "cuda:0 bf16[1, 32, 512, 128]" - t178 = prims.slice_prim(t177, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t178: "cuda:0 bf16[1, 32, 512, 64]" - t179 = prims.slice_prim(t177, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t179: "cuda:0 bf16[1, 32, 512, 64]" - t180 = prims.convert_element_type(t179, dtypes.float32) # t180: "cuda:0 f32[1, 32, 512, 64]" - t181 = prims.neg(t180) # t181: "cuda:0 f32[1, 32, 512, 64]" - t182 = prims.convert_element_type(t181, dtypes.bfloat16) # t182: "cuda:0 bf16[1, 32, 512, 64]" - t184 = prims.cat((t182, t178), -1) # t184: "cuda:0 bf16[1, 32, 512, 128]" - t185 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t185: "cuda:0 f32[1, 32, 512, 128]" - t186 = prims.convert_element_type(t177, dtypes.float32) # t186: "cuda:0 f32[1, 32, 512, 128]" - t187 = ltorch.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - # t187 = prims.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - t188 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t188: "cuda:0 f32[1, 32, 512, 128]" - t189 = prims.convert_element_type(t184, dtypes.float32) # t189: "cuda:0 f32[1, 32, 512, 128]" - t190 = ltorch.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - # t190 = prims.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - t191 = ltorch.add(t187, t190, alpha=None) # t191: "cuda:0 f32[1, 32, 512, 128]" - # t191 = prims.add(t187, t190) # t191: "cuda:0 f32[1, 32, 512, 128]" - t192 = prims.convert_element_type(t191, dtypes.bfloat16) # t192: "cuda:0 bf16[1, 32, 512, 128]" - t193 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t193: "cuda:0 bf16[1, 32, 512, 128]" - t194 = prims.slice_prim(t193, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t194: "cuda:0 bf16[1, 32, 512, 64]" - t195 = prims.slice_prim(t193, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t195: "cuda:0 bf16[1, 32, 512, 64]" - t196 = prims.convert_element_type(t195, dtypes.float32) # t196: "cuda:0 f32[1, 32, 512, 64]" - t197 = prims.neg(t196) # t197: "cuda:0 f32[1, 32, 512, 64]" - t198 = prims.convert_element_type(t197, dtypes.bfloat16) # t198: "cuda:0 bf16[1, 32, 512, 64]" - t200 = prims.cat((t198, t194), -1) # t200: "cuda:0 bf16[1, 32, 512, 128]" - t201 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t201: "cuda:0 f32[1, 32, 512, 128]" - t202 = prims.convert_element_type(t193, dtypes.float32) # t202: "cuda:0 f32[1, 32, 512, 128]" - t203 = ltorch.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - # t203 = prims.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - t204 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t204: "cuda:0 f32[1, 32, 512, 128]" - t205 = prims.convert_element_type(t200, dtypes.float32) # t205: "cuda:0 f32[1, 32, 512, 128]" - t206 = ltorch.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - # t206 = prims.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - t207 = ltorch.add(t203, t206, alpha=None) # t207: "cuda:0 f32[1, 32, 512, 128]" - # t207 = prims.add(t203, t206) # t207: "cuda:0 f32[1, 32, 512, 128]" - t208 = prims.convert_element_type(t207, dtypes.bfloat16) # t208: "cuda:0 bf16[1, 32, 512, 128]" - t209 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t209: "cuda:0 bf16[1, 32, 512, 0]" - t211 = prims.cat((t192, t209), -1) # t211: "cuda:0 bf16[1, 32, 512, 128]" - t212 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t212: "cuda:0 bf16[1, 32, 512, 0]" - t214 = prims.cat((t208, t212), -1) # t214: "cuda:0 bf16[1, 32, 512, 128]" - (t215, t216, t217, t218) = cudnn_sdpa_fwd(t211, t214, t176, None, 0.0, True, scale=0.08838834764831843) - t221 = prims.transpose(t215, (0, 2, 1, 3)) # t221: "cuda:0 bf16[1, 512, 32, 128]" - t225 = prims.reshape(t221, (1, 512, 4096)) # t225: "cuda:0 bf16[1, 512, 4096]" - t226 = prims.linear(t225, t_transformer_h_1_attn_proj_weight, None) # t226: "cuda:0 bf16[1, 512, 4096]" - t227 = prims.convert_element_type(t226, dtypes.float32) # t227: "cuda:0 f32[1, 512, 4096]" - t228 = prims.convert_element_type(t124, dtypes.float32) # t228: "cuda:0 f32[1, 512, 4096]" - t229 = ltorch.add(t227, t228, alpha=None) # t229: "cuda:0 f32[1, 512, 4096]" - # t229 = prims.add(t227, t228) # t229: "cuda:0 f32[1, 512, 4096]" - t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: "cuda:0 bf16[1, 512, 4096]" - t231 = prims.convert_element_type(t230, dtypes.float32) # t231: "cuda:0 f32[1, 512, 4096]" - t232 = ltorch.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - # t232 = prims.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - t234 = prims.sum(t232, (2,)) # t234: "cuda:0 f32[1, 512]" - t235 = prims.broadcast_in_dim(t234, [1, 512, 1], [0, 1]) # t235: "cuda:0 f32[1, 512, 1]" - t237 = ltorch.true_divide(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - # t237 = prims.div(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - t239 = ltorch.add(t237, 1e-05, alpha=None) # t239: "cuda:0 f32[1, 512, 1]" - # t239 = prims.add(t237, 1e-05) # t239: "cuda:0 f32[1, 512, 1]" - t240 = prims.rsqrt(t239) # t240: "cuda:0 f32[1, 512, 1]" - t241 = prims.broadcast_in_dim(t240, (1, 512, 4096), (0, 1, 2)) # t241: "cuda:0 f32[1, 512, 4096]" - t242 = ltorch.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - # t242 = prims.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - t243 = prims.convert_element_type(t242, dtypes.bfloat16) # t243: "cuda:0 bf16[1, 512, 4096]" - t244 = prims.broadcast_in_dim(t_transformer_h_1_norm_2_weight, (1, 512, 4096), (2,)) # t244: "cuda:0 bf16[1, 512, 4096]" - t245 = prims.convert_element_type(t243, dtypes.float32) # t245: "cuda:0 f32[1, 512, 4096]" - t246 = prims.convert_element_type(t244, dtypes.float32) # t246: "cuda:0 f32[1, 512, 4096]" - t247 = ltorch.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - # t247 = prims.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - t248 = prims.convert_element_type(t247, dtypes.bfloat16) # t248: "cuda:0 bf16[1, 512, 4096]" - t249 = prims.linear(t248, t_transformer_h_1_mlp_fc_1_weight, None) # t249: "cuda:0 bf16[1, 512, 11008]" - t250 = prims.linear(t248, t_transformer_h_1_mlp_fc_2_weight, None) # t250: "cuda:0 bf16[1, 512, 11008]" - t251 = prims.convert_element_type(t249, dtypes.float32) # t251: "cuda:0 f32[1, 512, 11008]" - t252 = prims.neg(t251) # t252: "cuda:0 f32[1, 512, 11008]" - t253 = prims.exp(t252) # t253: "cuda:0 f32[1, 512, 11008]" - t254 = ltorch.add(1.0, t253, alpha=None) # t254: "cuda:0 f32[1, 512, 11008]" - # t254 = prims.add(1.0, t253) # t254: "cuda:0 f32[1, 512, 11008]" - t255 = prims.reciprocal(t254) # t255: "cuda:0 f32[1, 512, 11008]" - t256 = prims.convert_element_type(t255, dtypes.bfloat16) # t256: "cuda:0 bf16[1, 512, 11008]" - t257 = prims.convert_element_type(t249, dtypes.float32) # t257: "cuda:0 f32[1, 512, 11008]" - t258 = prims.convert_element_type(t256, dtypes.float32) # t258: "cuda:0 f32[1, 512, 11008]" - t259 = ltorch.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - # t259 = prims.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 512, 11008]" - t261 = prims.convert_element_type(t260, dtypes.float32) # t261: "cuda:0 f32[1, 512, 11008]" - t262 = prims.convert_element_type(t250, dtypes.float32) # t262: "cuda:0 f32[1, 512, 11008]" - t263 = ltorch.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - # t263 = prims.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - t264 = prims.convert_element_type(t263, dtypes.bfloat16) # t264: "cuda:0 bf16[1, 512, 11008]" - t265 = prims.linear(t264, t_transformer_h_1_mlp_proj_weight, None) # t265: "cuda:0 bf16[1, 512, 4096]" - t266 = prims.convert_element_type(t265, dtypes.float32) # t266: "cuda:0 f32[1, 512, 4096]" - t267 = prims.convert_element_type(t230, dtypes.float32) # t267: "cuda:0 f32[1, 512, 4096]" - t268 = ltorch.add(t266, t267, alpha=None) # t268: "cuda:0 f32[1, 512, 4096]" - # t268 = prims.add(t266, t267) # t268: "cuda:0 f32[1, 512, 4096]" - t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: "cuda:0 bf16[1, 512, 4096]" - t270 = prims.convert_element_type(t269, dtypes.float32) # t270: "cuda:0 f32[1, 512, 4096]" - t271 = ltorch.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - # t271 = prims.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - t273 = prims.sum(t271, (2,)) # t273: "cuda:0 f32[1, 512]" - t274 = prims.broadcast_in_dim(t273, [1, 512, 1], [0, 1]) # t274: "cuda:0 f32[1, 512, 1]" - t276 = ltorch.true_divide(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - # t276 = prims.div(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - t278 = ltorch.add(t276, 1e-05, alpha=None) # t278: "cuda:0 f32[1, 512, 1]" - # t278 = prims.add(t276, 1e-05) # t278: "cuda:0 f32[1, 512, 1]" - t279 = prims.rsqrt(t278) # t279: "cuda:0 f32[1, 512, 1]" - t280 = prims.broadcast_in_dim(t279, (1, 512, 4096), (0, 1, 2)) # t280: "cuda:0 f32[1, 512, 4096]" - t281 = ltorch.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - # t281 = prims.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - t282 = prims.convert_element_type(t281, dtypes.bfloat16) # t282: "cuda:0 bf16[1, 512, 4096]" - t283 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_weight, (1, 512, 4096), (2,)) # t283: "cuda:0 bf16[1, 512, 4096]" - t284 = prims.convert_element_type(t282, dtypes.float32) # t284: "cuda:0 f32[1, 512, 4096]" - t285 = prims.convert_element_type(t283, dtypes.float32) # t285: "cuda:0 f32[1, 512, 4096]" - t286 = ltorch.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - # t286 = prims.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - t287 = prims.convert_element_type(t286, dtypes.bfloat16) # t287: "cuda:0 bf16[1, 512, 4096]" - t288 = prims.linear(t287, t_transformer_h_2_attn_attn_weight, None) # t288: "cuda:0 bf16[1, 512, 12288]" - t294 = prims.reshape(t288, (1, 512, 32, 3, 128)) # t294: "cuda:0 bf16[1, 512, 32, 3, 128]" - t300 = prims.transpose(t294, (0, 2, 3, 1, 4)) # t300: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t301, t302, t303) = ltorch.split(t300, (1, 1, 1), 2) - # t301 = prims.slice_prim(t300, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t301: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t302 = prims.slice_prim(t300, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t302: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t303 = prims.slice_prim(t300, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t303: "cuda:0 bf16[1, 32, 1, 512, 128]" - t309 = prims.reshape(t301, (1, 32, 512, 128)) # t309: "cuda:0 bf16[1, 32, 512, 128]" - t315 = prims.reshape(t302, (1, 32, 512, 128)) # t315: "cuda:0 bf16[1, 32, 512, 128]" - t321 = prims.reshape(t303, (1, 32, 512, 128)) # t321: "cuda:0 bf16[1, 32, 512, 128]" - t322 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t322: "cuda:0 bf16[1, 32, 512, 128]" - t323 = prims.slice_prim(t322, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t323: "cuda:0 bf16[1, 32, 512, 64]" - t324 = prims.slice_prim(t322, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t324: "cuda:0 bf16[1, 32, 512, 64]" - t325 = prims.convert_element_type(t324, dtypes.float32) # t325: "cuda:0 f32[1, 32, 512, 64]" - t326 = prims.neg(t325) # t326: "cuda:0 f32[1, 32, 512, 64]" - t327 = prims.convert_element_type(t326, dtypes.bfloat16) # t327: "cuda:0 bf16[1, 32, 512, 64]" - t329 = prims.cat((t327, t323), -1) # t329: "cuda:0 bf16[1, 32, 512, 128]" - t330 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t330: "cuda:0 f32[1, 32, 512, 128]" - t331 = prims.convert_element_type(t322, dtypes.float32) # t331: "cuda:0 f32[1, 32, 512, 128]" - t332 = ltorch.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - # t332 = prims.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - t333 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t333: "cuda:0 f32[1, 32, 512, 128]" - t334 = prims.convert_element_type(t329, dtypes.float32) # t334: "cuda:0 f32[1, 32, 512, 128]" - t335 = ltorch.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - # t335 = prims.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - t336 = ltorch.add(t332, t335, alpha=None) # t336: "cuda:0 f32[1, 32, 512, 128]" - # t336 = prims.add(t332, t335) # t336: "cuda:0 f32[1, 32, 512, 128]" - t337 = prims.convert_element_type(t336, dtypes.bfloat16) # t337: "cuda:0 bf16[1, 32, 512, 128]" - t338 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t338: "cuda:0 bf16[1, 32, 512, 128]" - t339 = prims.slice_prim(t338, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t339: "cuda:0 bf16[1, 32, 512, 64]" - t340 = prims.slice_prim(t338, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t340: "cuda:0 bf16[1, 32, 512, 64]" - t341 = prims.convert_element_type(t340, dtypes.float32) # t341: "cuda:0 f32[1, 32, 512, 64]" - t342 = prims.neg(t341) # t342: "cuda:0 f32[1, 32, 512, 64]" - t343 = prims.convert_element_type(t342, dtypes.bfloat16) # t343: "cuda:0 bf16[1, 32, 512, 64]" - t345 = prims.cat((t343, t339), -1) # t345: "cuda:0 bf16[1, 32, 512, 128]" - t346 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t346: "cuda:0 f32[1, 32, 512, 128]" - t347 = prims.convert_element_type(t338, dtypes.float32) # t347: "cuda:0 f32[1, 32, 512, 128]" - t348 = ltorch.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - # t348 = prims.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - t349 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t349: "cuda:0 f32[1, 32, 512, 128]" - t350 = prims.convert_element_type(t345, dtypes.float32) # t350: "cuda:0 f32[1, 32, 512, 128]" - t351 = ltorch.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - # t351 = prims.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - t352 = ltorch.add(t348, t351, alpha=None) # t352: "cuda:0 f32[1, 32, 512, 128]" - # t352 = prims.add(t348, t351) # t352: "cuda:0 f32[1, 32, 512, 128]" - t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: "cuda:0 bf16[1, 32, 512, 128]" - t354 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t354: "cuda:0 bf16[1, 32, 512, 0]" - t356 = prims.cat((t337, t354), -1) # t356: "cuda:0 bf16[1, 32, 512, 128]" - t357 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t357: "cuda:0 bf16[1, 32, 512, 0]" - t359 = prims.cat((t353, t357), -1) # t359: "cuda:0 bf16[1, 32, 512, 128]" - (t360, t361, t362, t363) = cudnn_sdpa_fwd(t356, t359, t321, None, 0.0, True, scale=0.08838834764831843) - t366 = prims.transpose(t360, (0, 2, 1, 3)) # t366: "cuda:0 bf16[1, 512, 32, 128]" - t370 = prims.reshape(t366, (1, 512, 4096)) # t370: "cuda:0 bf16[1, 512, 4096]" - t371 = prims.linear(t370, t_transformer_h_2_attn_proj_weight, None) # t371: "cuda:0 bf16[1, 512, 4096]" - t372 = prims.convert_element_type(t371, dtypes.float32) # t372: "cuda:0 f32[1, 512, 4096]" - t373 = prims.convert_element_type(t269, dtypes.float32) # t373: "cuda:0 f32[1, 512, 4096]" - t374 = ltorch.add(t372, t373, alpha=None) # t374: "cuda:0 f32[1, 512, 4096]" - # t374 = prims.add(t372, t373) # t374: "cuda:0 f32[1, 512, 4096]" - t375 = prims.convert_element_type(t374, dtypes.bfloat16) # t375: "cuda:0 bf16[1, 512, 4096]" - t376 = prims.convert_element_type(t375, dtypes.float32) # t376: "cuda:0 f32[1, 512, 4096]" - t377 = ltorch.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - # t377 = prims.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - t379 = prims.sum(t377, (2,)) # t379: "cuda:0 f32[1, 512]" - t380 = prims.broadcast_in_dim(t379, [1, 512, 1], [0, 1]) # t380: "cuda:0 f32[1, 512, 1]" - t382 = ltorch.true_divide(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - # t382 = prims.div(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - t384 = ltorch.add(t382, 1e-05, alpha=None) # t384: "cuda:0 f32[1, 512, 1]" - # t384 = prims.add(t382, 1e-05) # t384: "cuda:0 f32[1, 512, 1]" - t385 = prims.rsqrt(t384) # t385: "cuda:0 f32[1, 512, 1]" - t386 = prims.broadcast_in_dim(t385, (1, 512, 4096), (0, 1, 2)) # t386: "cuda:0 f32[1, 512, 4096]" - t387 = ltorch.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - # t387 = prims.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - t388 = prims.convert_element_type(t387, dtypes.bfloat16) # t388: "cuda:0 bf16[1, 512, 4096]" - t389 = prims.broadcast_in_dim(t_transformer_h_2_norm_2_weight, (1, 512, 4096), (2,)) # t389: "cuda:0 bf16[1, 512, 4096]" - t390 = prims.convert_element_type(t388, dtypes.float32) # t390: "cuda:0 f32[1, 512, 4096]" - t391 = prims.convert_element_type(t389, dtypes.float32) # t391: "cuda:0 f32[1, 512, 4096]" - t392 = ltorch.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - # t392 = prims.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - t393 = prims.convert_element_type(t392, dtypes.bfloat16) # t393: "cuda:0 bf16[1, 512, 4096]" - t394 = prims.linear(t393, t_transformer_h_2_mlp_fc_1_weight, None) # t394: "cuda:0 bf16[1, 512, 11008]" - t395 = prims.linear(t393, t_transformer_h_2_mlp_fc_2_weight, None) # t395: "cuda:0 bf16[1, 512, 11008]" - t396 = prims.convert_element_type(t394, dtypes.float32) # t396: "cuda:0 f32[1, 512, 11008]" - t397 = prims.neg(t396) # t397: "cuda:0 f32[1, 512, 11008]" - t398 = prims.exp(t397) # t398: "cuda:0 f32[1, 512, 11008]" - t399 = ltorch.add(1.0, t398, alpha=None) # t399: "cuda:0 f32[1, 512, 11008]" - # t399 = prims.add(1.0, t398) # t399: "cuda:0 f32[1, 512, 11008]" - t400 = prims.reciprocal(t399) # t400: "cuda:0 f32[1, 512, 11008]" - t401 = prims.convert_element_type(t400, dtypes.bfloat16) # t401: "cuda:0 bf16[1, 512, 11008]" - t402 = prims.convert_element_type(t394, dtypes.float32) # t402: "cuda:0 f32[1, 512, 11008]" - t403 = prims.convert_element_type(t401, dtypes.float32) # t403: "cuda:0 f32[1, 512, 11008]" - t404 = ltorch.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - # t404 = prims.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - t405 = prims.convert_element_type(t404, dtypes.bfloat16) # t405: "cuda:0 bf16[1, 512, 11008]" - t406 = prims.convert_element_type(t405, dtypes.float32) # t406: "cuda:0 f32[1, 512, 11008]" - t407 = prims.convert_element_type(t395, dtypes.float32) # t407: "cuda:0 f32[1, 512, 11008]" - t408 = ltorch.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - # t408 = prims.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - t409 = prims.convert_element_type(t408, dtypes.bfloat16) # t409: "cuda:0 bf16[1, 512, 11008]" - t410 = prims.linear(t409, t_transformer_h_2_mlp_proj_weight, None) # t410: "cuda:0 bf16[1, 512, 4096]" - t411 = prims.convert_element_type(t410, dtypes.float32) # t411: "cuda:0 f32[1, 512, 4096]" - t412 = prims.convert_element_type(t375, dtypes.float32) # t412: "cuda:0 f32[1, 512, 4096]" - t413 = ltorch.add(t411, t412, alpha=None) # t413: "cuda:0 f32[1, 512, 4096]" - # t413 = prims.add(t411, t412) # t413: "cuda:0 f32[1, 512, 4096]" - t414 = prims.convert_element_type(t413, dtypes.bfloat16) # t414: "cuda:0 bf16[1, 512, 4096]" - t415 = prims.convert_element_type(t414, dtypes.float32) # t415: "cuda:0 f32[1, 512, 4096]" - t416 = ltorch.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - # t416 = prims.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - t418 = prims.sum(t416, (2,)) # t418: "cuda:0 f32[1, 512]" - t419 = prims.broadcast_in_dim(t418, [1, 512, 1], [0, 1]) # t419: "cuda:0 f32[1, 512, 1]" - t421 = ltorch.true_divide(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - # t421 = prims.div(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - t423 = ltorch.add(t421, 1e-05, alpha=None) # t423: "cuda:0 f32[1, 512, 1]" - # t423 = prims.add(t421, 1e-05) # t423: "cuda:0 f32[1, 512, 1]" - t424 = prims.rsqrt(t423) # t424: "cuda:0 f32[1, 512, 1]" - t425 = prims.broadcast_in_dim(t424, (1, 512, 4096), (0, 1, 2)) # t425: "cuda:0 f32[1, 512, 4096]" - t426 = ltorch.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - # t426 = prims.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - t427 = prims.convert_element_type(t426, dtypes.bfloat16) # t427: "cuda:0 bf16[1, 512, 4096]" - t428 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_weight, (1, 512, 4096), (2,)) # t428: "cuda:0 bf16[1, 512, 4096]" - t429 = prims.convert_element_type(t427, dtypes.float32) # t429: "cuda:0 f32[1, 512, 4096]" - t430 = prims.convert_element_type(t428, dtypes.float32) # t430: "cuda:0 f32[1, 512, 4096]" - t431 = ltorch.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - # t431 = prims.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - t432 = prims.convert_element_type(t431, dtypes.bfloat16) # t432: "cuda:0 bf16[1, 512, 4096]" - t433 = prims.linear(t432, t_transformer_h_3_attn_attn_weight, None) # t433: "cuda:0 bf16[1, 512, 12288]" - t439 = prims.reshape(t433, (1, 512, 32, 3, 128)) # t439: "cuda:0 bf16[1, 512, 32, 3, 128]" - t445 = prims.transpose(t439, (0, 2, 3, 1, 4)) # t445: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t446, t447, t448) = ltorch.split(t445, (1, 1, 1), 2) - # t446 = prims.slice_prim(t445, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t446: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t447 = prims.slice_prim(t445, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t447: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t448 = prims.slice_prim(t445, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t448: "cuda:0 bf16[1, 32, 1, 512, 128]" - t454 = prims.reshape(t446, (1, 32, 512, 128)) # t454: "cuda:0 bf16[1, 32, 512, 128]" - t460 = prims.reshape(t447, (1, 32, 512, 128)) # t460: "cuda:0 bf16[1, 32, 512, 128]" - t466 = prims.reshape(t448, (1, 32, 512, 128)) # t466: "cuda:0 bf16[1, 32, 512, 128]" - t467 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t467: "cuda:0 bf16[1, 32, 512, 128]" - t468 = prims.slice_prim(t467, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t468: "cuda:0 bf16[1, 32, 512, 64]" - t469 = prims.slice_prim(t467, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t469: "cuda:0 bf16[1, 32, 512, 64]" - t470 = prims.convert_element_type(t469, dtypes.float32) # t470: "cuda:0 f32[1, 32, 512, 64]" - t471 = prims.neg(t470) # t471: "cuda:0 f32[1, 32, 512, 64]" - t472 = prims.convert_element_type(t471, dtypes.bfloat16) # t472: "cuda:0 bf16[1, 32, 512, 64]" - t474 = prims.cat((t472, t468), -1) # t474: "cuda:0 bf16[1, 32, 512, 128]" - t475 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t475: "cuda:0 f32[1, 32, 512, 128]" - t476 = prims.convert_element_type(t467, dtypes.float32) # t476: "cuda:0 f32[1, 32, 512, 128]" - t477 = ltorch.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - # t477 = prims.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - t478 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t478: "cuda:0 f32[1, 32, 512, 128]" - t479 = prims.convert_element_type(t474, dtypes.float32) # t479: "cuda:0 f32[1, 32, 512, 128]" - t480 = ltorch.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - # t480 = prims.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - t481 = ltorch.add(t477, t480, alpha=None) # t481: "cuda:0 f32[1, 32, 512, 128]" - # t481 = prims.add(t477, t480) # t481: "cuda:0 f32[1, 32, 512, 128]" - t482 = prims.convert_element_type(t481, dtypes.bfloat16) # t482: "cuda:0 bf16[1, 32, 512, 128]" - t483 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t483: "cuda:0 bf16[1, 32, 512, 128]" - t484 = prims.slice_prim(t483, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t484: "cuda:0 bf16[1, 32, 512, 64]" - t485 = prims.slice_prim(t483, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t485: "cuda:0 bf16[1, 32, 512, 64]" - t486 = prims.convert_element_type(t485, dtypes.float32) # t486: "cuda:0 f32[1, 32, 512, 64]" - t487 = prims.neg(t486) # t487: "cuda:0 f32[1, 32, 512, 64]" - t488 = prims.convert_element_type(t487, dtypes.bfloat16) # t488: "cuda:0 bf16[1, 32, 512, 64]" - t490 = prims.cat((t488, t484), -1) # t490: "cuda:0 bf16[1, 32, 512, 128]" - t491 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t491: "cuda:0 f32[1, 32, 512, 128]" - t492 = prims.convert_element_type(t483, dtypes.float32) # t492: "cuda:0 f32[1, 32, 512, 128]" - t493 = ltorch.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - # t493 = prims.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - t494 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t494: "cuda:0 f32[1, 32, 512, 128]" - t495 = prims.convert_element_type(t490, dtypes.float32) # t495: "cuda:0 f32[1, 32, 512, 128]" - t496 = ltorch.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - # t496 = prims.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - t497 = ltorch.add(t493, t496, alpha=None) # t497: "cuda:0 f32[1, 32, 512, 128]" - # t497 = prims.add(t493, t496) # t497: "cuda:0 f32[1, 32, 512, 128]" - t498 = prims.convert_element_type(t497, dtypes.bfloat16) # t498: "cuda:0 bf16[1, 32, 512, 128]" - t499 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t499: "cuda:0 bf16[1, 32, 512, 0]" - t501 = prims.cat((t482, t499), -1) # t501: "cuda:0 bf16[1, 32, 512, 128]" - t502 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t502: "cuda:0 bf16[1, 32, 512, 0]" - t504 = prims.cat((t498, t502), -1) # t504: "cuda:0 bf16[1, 32, 512, 128]" - (t505, t506, t507, t508) = cudnn_sdpa_fwd(t501, t504, t466, None, 0.0, True, scale=0.08838834764831843) - t511 = prims.transpose(t505, (0, 2, 1, 3)) # t511: "cuda:0 bf16[1, 512, 32, 128]" - t515 = prims.reshape(t511, (1, 512, 4096)) # t515: "cuda:0 bf16[1, 512, 4096]" - t516 = prims.linear(t515, t_transformer_h_3_attn_proj_weight, None) # t516: "cuda:0 bf16[1, 512, 4096]" - t517 = prims.convert_element_type(t516, dtypes.float32) # t517: "cuda:0 f32[1, 512, 4096]" - t518 = prims.convert_element_type(t414, dtypes.float32) # t518: "cuda:0 f32[1, 512, 4096]" - t519 = ltorch.add(t517, t518, alpha=None) # t519: "cuda:0 f32[1, 512, 4096]" - # t519 = prims.add(t517, t518) # t519: "cuda:0 f32[1, 512, 4096]" - t520 = prims.convert_element_type(t519, dtypes.bfloat16) # t520: "cuda:0 bf16[1, 512, 4096]" - t521 = prims.convert_element_type(t520, dtypes.float32) # t521: "cuda:0 f32[1, 512, 4096]" - t522 = ltorch.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - # t522 = prims.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - t524 = prims.sum(t522, (2,)) # t524: "cuda:0 f32[1, 512]" - t525 = prims.broadcast_in_dim(t524, [1, 512, 1], [0, 1]) # t525: "cuda:0 f32[1, 512, 1]" - t527 = ltorch.true_divide(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - # t527 = prims.div(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - t529 = ltorch.add(t527, 1e-05, alpha=None) # t529: "cuda:0 f32[1, 512, 1]" - # t529 = prims.add(t527, 1e-05) # t529: "cuda:0 f32[1, 512, 1]" - t530 = prims.rsqrt(t529) # t530: "cuda:0 f32[1, 512, 1]" - t531 = prims.broadcast_in_dim(t530, (1, 512, 4096), (0, 1, 2)) # t531: "cuda:0 f32[1, 512, 4096]" - t532 = ltorch.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - # t532 = prims.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: "cuda:0 bf16[1, 512, 4096]" - t534 = prims.broadcast_in_dim(t_transformer_h_3_norm_2_weight, (1, 512, 4096), (2,)) # t534: "cuda:0 bf16[1, 512, 4096]" - t535 = prims.convert_element_type(t533, dtypes.float32) # t535: "cuda:0 f32[1, 512, 4096]" - t536 = prims.convert_element_type(t534, dtypes.float32) # t536: "cuda:0 f32[1, 512, 4096]" - t537 = ltorch.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - # t537 = prims.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - t538 = prims.convert_element_type(t537, dtypes.bfloat16) # t538: "cuda:0 bf16[1, 512, 4096]" - t539 = prims.linear(t538, t_transformer_h_3_mlp_fc_1_weight, None) # t539: "cuda:0 bf16[1, 512, 11008]" - t540 = prims.linear(t538, t_transformer_h_3_mlp_fc_2_weight, None) # t540: "cuda:0 bf16[1, 512, 11008]" - t541 = prims.convert_element_type(t539, dtypes.float32) # t541: "cuda:0 f32[1, 512, 11008]" - t542 = prims.neg(t541) # t542: "cuda:0 f32[1, 512, 11008]" - t543 = prims.exp(t542) # t543: "cuda:0 f32[1, 512, 11008]" - t544 = ltorch.add(1.0, t543, alpha=None) # t544: "cuda:0 f32[1, 512, 11008]" - # t544 = prims.add(1.0, t543) # t544: "cuda:0 f32[1, 512, 11008]" - t545 = prims.reciprocal(t544) # t545: "cuda:0 f32[1, 512, 11008]" - t546 = prims.convert_element_type(t545, dtypes.bfloat16) # t546: "cuda:0 bf16[1, 512, 11008]" - t547 = prims.convert_element_type(t539, dtypes.float32) # t547: "cuda:0 f32[1, 512, 11008]" - t548 = prims.convert_element_type(t546, dtypes.float32) # t548: "cuda:0 f32[1, 512, 11008]" - t549 = ltorch.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - # t549 = prims.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - t550 = prims.convert_element_type(t549, dtypes.bfloat16) # t550: "cuda:0 bf16[1, 512, 11008]" - t551 = prims.convert_element_type(t550, dtypes.float32) # t551: "cuda:0 f32[1, 512, 11008]" - t552 = prims.convert_element_type(t540, dtypes.float32) # t552: "cuda:0 f32[1, 512, 11008]" - t553 = ltorch.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - # t553 = prims.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: "cuda:0 bf16[1, 512, 11008]" - t555 = prims.linear(t554, t_transformer_h_3_mlp_proj_weight, None) # t555: "cuda:0 bf16[1, 512, 4096]" - t556 = prims.convert_element_type(t555, dtypes.float32) # t556: "cuda:0 f32[1, 512, 4096]" - t557 = prims.convert_element_type(t520, dtypes.float32) # t557: "cuda:0 f32[1, 512, 4096]" - t558 = ltorch.add(t556, t557, alpha=None) # t558: "cuda:0 f32[1, 512, 4096]" - # t558 = prims.add(t556, t557) # t558: "cuda:0 f32[1, 512, 4096]" - t559 = prims.convert_element_type(t558, dtypes.bfloat16) # t559: "cuda:0 bf16[1, 512, 4096]" - t560 = prims.convert_element_type(t559, dtypes.float32) # t560: "cuda:0 f32[1, 512, 4096]" - t561 = ltorch.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - # t561 = prims.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - t563 = prims.sum(t561, (2,)) # t563: "cuda:0 f32[1, 512]" - t564 = prims.broadcast_in_dim(t563, [1, 512, 1], [0, 1]) # t564: "cuda:0 f32[1, 512, 1]" - t566 = ltorch.true_divide(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - # t566 = prims.div(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - t568 = ltorch.add(t566, 1e-05, alpha=None) # t568: "cuda:0 f32[1, 512, 1]" - # t568 = prims.add(t566, 1e-05) # t568: "cuda:0 f32[1, 512, 1]" - t569 = prims.rsqrt(t568) # t569: "cuda:0 f32[1, 512, 1]" - t570 = prims.broadcast_in_dim(t569, (1, 512, 4096), (0, 1, 2)) # t570: "cuda:0 f32[1, 512, 4096]" - t571 = ltorch.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - # t571 = prims.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - t572 = prims.convert_element_type(t571, dtypes.bfloat16) # t572: "cuda:0 bf16[1, 512, 4096]" - t573 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_weight, (1, 512, 4096), (2,)) # t573: "cuda:0 bf16[1, 512, 4096]" - t574 = prims.convert_element_type(t572, dtypes.float32) # t574: "cuda:0 f32[1, 512, 4096]" - t575 = prims.convert_element_type(t573, dtypes.float32) # t575: "cuda:0 f32[1, 512, 4096]" - t576 = ltorch.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - # t576 = prims.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - t577 = prims.convert_element_type(t576, dtypes.bfloat16) # t577: "cuda:0 bf16[1, 512, 4096]" - t578 = prims.linear(t577, t_transformer_h_4_attn_attn_weight, None) # t578: "cuda:0 bf16[1, 512, 12288]" - t584 = prims.reshape(t578, (1, 512, 32, 3, 128)) # t584: "cuda:0 bf16[1, 512, 32, 3, 128]" - t590 = prims.transpose(t584, (0, 2, 3, 1, 4)) # t590: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t591, t592, t593) = ltorch.split(t590, (1, 1, 1), 2) - # t591 = prims.slice_prim(t590, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t591: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t592 = prims.slice_prim(t590, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t592: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t593 = prims.slice_prim(t590, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t593: "cuda:0 bf16[1, 32, 1, 512, 128]" - t599 = prims.reshape(t591, (1, 32, 512, 128)) # t599: "cuda:0 bf16[1, 32, 512, 128]" - t605 = prims.reshape(t592, (1, 32, 512, 128)) # t605: "cuda:0 bf16[1, 32, 512, 128]" - t611 = prims.reshape(t593, (1, 32, 512, 128)) # t611: "cuda:0 bf16[1, 32, 512, 128]" - t612 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t612: "cuda:0 bf16[1, 32, 512, 128]" - t613 = prims.slice_prim(t612, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t613: "cuda:0 bf16[1, 32, 512, 64]" - t614 = prims.slice_prim(t612, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t614: "cuda:0 bf16[1, 32, 512, 64]" - t615 = prims.convert_element_type(t614, dtypes.float32) # t615: "cuda:0 f32[1, 32, 512, 64]" - t616 = prims.neg(t615) # t616: "cuda:0 f32[1, 32, 512, 64]" - t617 = prims.convert_element_type(t616, dtypes.bfloat16) # t617: "cuda:0 bf16[1, 32, 512, 64]" - t619 = prims.cat((t617, t613), -1) # t619: "cuda:0 bf16[1, 32, 512, 128]" - t620 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t620: "cuda:0 f32[1, 32, 512, 128]" - t621 = prims.convert_element_type(t612, dtypes.float32) # t621: "cuda:0 f32[1, 32, 512, 128]" - t622 = ltorch.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - # t622 = prims.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - t623 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t623: "cuda:0 f32[1, 32, 512, 128]" - t624 = prims.convert_element_type(t619, dtypes.float32) # t624: "cuda:0 f32[1, 32, 512, 128]" - t625 = ltorch.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - # t625 = prims.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - t626 = ltorch.add(t622, t625, alpha=None) # t626: "cuda:0 f32[1, 32, 512, 128]" - # t626 = prims.add(t622, t625) # t626: "cuda:0 f32[1, 32, 512, 128]" - t627 = prims.convert_element_type(t626, dtypes.bfloat16) # t627: "cuda:0 bf16[1, 32, 512, 128]" - t628 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t628: "cuda:0 bf16[1, 32, 512, 128]" - t629 = prims.slice_prim(t628, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t629: "cuda:0 bf16[1, 32, 512, 64]" - t630 = prims.slice_prim(t628, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t630: "cuda:0 bf16[1, 32, 512, 64]" - t631 = prims.convert_element_type(t630, dtypes.float32) # t631: "cuda:0 f32[1, 32, 512, 64]" - t632 = prims.neg(t631) # t632: "cuda:0 f32[1, 32, 512, 64]" - t633 = prims.convert_element_type(t632, dtypes.bfloat16) # t633: "cuda:0 bf16[1, 32, 512, 64]" - t635 = prims.cat((t633, t629), -1) # t635: "cuda:0 bf16[1, 32, 512, 128]" - t636 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t636: "cuda:0 f32[1, 32, 512, 128]" - t637 = prims.convert_element_type(t628, dtypes.float32) # t637: "cuda:0 f32[1, 32, 512, 128]" - t638 = ltorch.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - # t638 = prims.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - t639 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t639: "cuda:0 f32[1, 32, 512, 128]" - t640 = prims.convert_element_type(t635, dtypes.float32) # t640: "cuda:0 f32[1, 32, 512, 128]" - t641 = ltorch.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - # t641 = prims.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - t642 = ltorch.add(t638, t641, alpha=None) # t642: "cuda:0 f32[1, 32, 512, 128]" - # t642 = prims.add(t638, t641) # t642: "cuda:0 f32[1, 32, 512, 128]" - t643 = prims.convert_element_type(t642, dtypes.bfloat16) # t643: "cuda:0 bf16[1, 32, 512, 128]" - t644 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t644: "cuda:0 bf16[1, 32, 512, 0]" - t646 = prims.cat((t627, t644), -1) # t646: "cuda:0 bf16[1, 32, 512, 128]" - t647 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t647: "cuda:0 bf16[1, 32, 512, 0]" - t649 = prims.cat((t643, t647), -1) # t649: "cuda:0 bf16[1, 32, 512, 128]" - (t650, t651, t652, t653) = cudnn_sdpa_fwd(t646, t649, t611, None, 0.0, True, scale=0.08838834764831843) - t656 = prims.transpose(t650, (0, 2, 1, 3)) # t656: "cuda:0 bf16[1, 512, 32, 128]" - t660 = prims.reshape(t656, (1, 512, 4096)) # t660: "cuda:0 bf16[1, 512, 4096]" - t661 = prims.linear(t660, t_transformer_h_4_attn_proj_weight, None) # t661: "cuda:0 bf16[1, 512, 4096]" - t662 = prims.convert_element_type(t661, dtypes.float32) # t662: "cuda:0 f32[1, 512, 4096]" - t663 = prims.convert_element_type(t559, dtypes.float32) # t663: "cuda:0 f32[1, 512, 4096]" - t664 = ltorch.add(t662, t663, alpha=None) # t664: "cuda:0 f32[1, 512, 4096]" - # t664 = prims.add(t662, t663) # t664: "cuda:0 f32[1, 512, 4096]" - t665 = prims.convert_element_type(t664, dtypes.bfloat16) # t665: "cuda:0 bf16[1, 512, 4096]" - t666 = prims.convert_element_type(t665, dtypes.float32) # t666: "cuda:0 f32[1, 512, 4096]" - t667 = ltorch.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - # t667 = prims.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - t669 = prims.sum(t667, (2,)) # t669: "cuda:0 f32[1, 512]" - t670 = prims.broadcast_in_dim(t669, [1, 512, 1], [0, 1]) # t670: "cuda:0 f32[1, 512, 1]" - t672 = ltorch.true_divide(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - # t672 = prims.div(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - t674 = ltorch.add(t672, 1e-05, alpha=None) # t674: "cuda:0 f32[1, 512, 1]" - # t674 = prims.add(t672, 1e-05) # t674: "cuda:0 f32[1, 512, 1]" - t675 = prims.rsqrt(t674) # t675: "cuda:0 f32[1, 512, 1]" - t676 = prims.broadcast_in_dim(t675, (1, 512, 4096), (0, 1, 2)) # t676: "cuda:0 f32[1, 512, 4096]" - t677 = ltorch.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - # t677 = prims.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - t678 = prims.convert_element_type(t677, dtypes.bfloat16) # t678: "cuda:0 bf16[1, 512, 4096]" - t679 = prims.broadcast_in_dim(t_transformer_h_4_norm_2_weight, (1, 512, 4096), (2,)) # t679: "cuda:0 bf16[1, 512, 4096]" - t680 = prims.convert_element_type(t678, dtypes.float32) # t680: "cuda:0 f32[1, 512, 4096]" - t681 = prims.convert_element_type(t679, dtypes.float32) # t681: "cuda:0 f32[1, 512, 4096]" - t682 = ltorch.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - # t682 = prims.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - t683 = prims.convert_element_type(t682, dtypes.bfloat16) # t683: "cuda:0 bf16[1, 512, 4096]" - t684 = prims.linear(t683, t_transformer_h_4_mlp_fc_1_weight, None) # t684: "cuda:0 bf16[1, 512, 11008]" - t685 = prims.linear(t683, t_transformer_h_4_mlp_fc_2_weight, None) # t685: "cuda:0 bf16[1, 512, 11008]" - t686 = prims.convert_element_type(t684, dtypes.float32) # t686: "cuda:0 f32[1, 512, 11008]" - t687 = prims.neg(t686) # t687: "cuda:0 f32[1, 512, 11008]" - t688 = prims.exp(t687) # t688: "cuda:0 f32[1, 512, 11008]" - t689 = ltorch.add(1.0, t688, alpha=None) # t689: "cuda:0 f32[1, 512, 11008]" - # t689 = prims.add(1.0, t688) # t689: "cuda:0 f32[1, 512, 11008]" - t690 = prims.reciprocal(t689) # t690: "cuda:0 f32[1, 512, 11008]" - t691 = prims.convert_element_type(t690, dtypes.bfloat16) # t691: "cuda:0 bf16[1, 512, 11008]" - t692 = prims.convert_element_type(t684, dtypes.float32) # t692: "cuda:0 f32[1, 512, 11008]" - t693 = prims.convert_element_type(t691, dtypes.float32) # t693: "cuda:0 f32[1, 512, 11008]" - t694 = ltorch.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - # t694 = prims.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - t695 = prims.convert_element_type(t694, dtypes.bfloat16) # t695: "cuda:0 bf16[1, 512, 11008]" - t696 = prims.convert_element_type(t695, dtypes.float32) # t696: "cuda:0 f32[1, 512, 11008]" - t697 = prims.convert_element_type(t685, dtypes.float32) # t697: "cuda:0 f32[1, 512, 11008]" - t698 = ltorch.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - # t698 = prims.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - t699 = prims.convert_element_type(t698, dtypes.bfloat16) # t699: "cuda:0 bf16[1, 512, 11008]" - t700 = prims.linear(t699, t_transformer_h_4_mlp_proj_weight, None) # t700: "cuda:0 bf16[1, 512, 4096]" - t701 = prims.convert_element_type(t700, dtypes.float32) # t701: "cuda:0 f32[1, 512, 4096]" - t702 = prims.convert_element_type(t665, dtypes.float32) # t702: "cuda:0 f32[1, 512, 4096]" - t703 = ltorch.add(t701, t702, alpha=None) # t703: "cuda:0 f32[1, 512, 4096]" - # t703 = prims.add(t701, t702) # t703: "cuda:0 f32[1, 512, 4096]" - t704 = prims.convert_element_type(t703, dtypes.bfloat16) # t704: "cuda:0 bf16[1, 512, 4096]" - t705 = prims.convert_element_type(t704, dtypes.float32) # t705: "cuda:0 f32[1, 512, 4096]" - t706 = ltorch.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - # t706 = prims.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - t708 = prims.sum(t706, (2,)) # t708: "cuda:0 f32[1, 512]" - t709 = prims.broadcast_in_dim(t708, [1, 512, 1], [0, 1]) # t709: "cuda:0 f32[1, 512, 1]" - t711 = ltorch.true_divide(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - # t711 = prims.div(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - t713 = ltorch.add(t711, 1e-05, alpha=None) # t713: "cuda:0 f32[1, 512, 1]" - # t713 = prims.add(t711, 1e-05) # t713: "cuda:0 f32[1, 512, 1]" - t714 = prims.rsqrt(t713) # t714: "cuda:0 f32[1, 512, 1]" - t715 = prims.broadcast_in_dim(t714, (1, 512, 4096), (0, 1, 2)) # t715: "cuda:0 f32[1, 512, 4096]" - t716 = ltorch.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - # t716 = prims.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - t717 = prims.convert_element_type(t716, dtypes.bfloat16) # t717: "cuda:0 bf16[1, 512, 4096]" - t718 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_weight, (1, 512, 4096), (2,)) # t718: "cuda:0 bf16[1, 512, 4096]" - t719 = prims.convert_element_type(t717, dtypes.float32) # t719: "cuda:0 f32[1, 512, 4096]" - t720 = prims.convert_element_type(t718, dtypes.float32) # t720: "cuda:0 f32[1, 512, 4096]" - t721 = ltorch.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - # t721 = prims.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - t722 = prims.convert_element_type(t721, dtypes.bfloat16) # t722: "cuda:0 bf16[1, 512, 4096]" - t723 = prims.linear(t722, t_transformer_h_5_attn_attn_weight, None) # t723: "cuda:0 bf16[1, 512, 12288]" - t729 = prims.reshape(t723, (1, 512, 32, 3, 128)) # t729: "cuda:0 bf16[1, 512, 32, 3, 128]" - t735 = prims.transpose(t729, (0, 2, 3, 1, 4)) # t735: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t736, t737, t738) = ltorch.split(t735, (1, 1, 1), 2) - # t736 = prims.slice_prim(t735, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t736: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t737 = prims.slice_prim(t735, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t737: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t738 = prims.slice_prim(t735, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t738: "cuda:0 bf16[1, 32, 1, 512, 128]" - t744 = prims.reshape(t736, (1, 32, 512, 128)) # t744: "cuda:0 bf16[1, 32, 512, 128]" - t750 = prims.reshape(t737, (1, 32, 512, 128)) # t750: "cuda:0 bf16[1, 32, 512, 128]" - t756 = prims.reshape(t738, (1, 32, 512, 128)) # t756: "cuda:0 bf16[1, 32, 512, 128]" - t757 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t757: "cuda:0 bf16[1, 32, 512, 128]" - t758 = prims.slice_prim(t757, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t758: "cuda:0 bf16[1, 32, 512, 64]" - t759 = prims.slice_prim(t757, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t759: "cuda:0 bf16[1, 32, 512, 64]" - t760 = prims.convert_element_type(t759, dtypes.float32) # t760: "cuda:0 f32[1, 32, 512, 64]" - t761 = prims.neg(t760) # t761: "cuda:0 f32[1, 32, 512, 64]" - t762 = prims.convert_element_type(t761, dtypes.bfloat16) # t762: "cuda:0 bf16[1, 32, 512, 64]" - t764 = prims.cat((t762, t758), -1) # t764: "cuda:0 bf16[1, 32, 512, 128]" - t765 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t765: "cuda:0 f32[1, 32, 512, 128]" - t766 = prims.convert_element_type(t757, dtypes.float32) # t766: "cuda:0 f32[1, 32, 512, 128]" - t767 = ltorch.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - # t767 = prims.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - t768 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t768: "cuda:0 f32[1, 32, 512, 128]" - t769 = prims.convert_element_type(t764, dtypes.float32) # t769: "cuda:0 f32[1, 32, 512, 128]" - t770 = ltorch.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - # t770 = prims.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - t771 = ltorch.add(t767, t770, alpha=None) # t771: "cuda:0 f32[1, 32, 512, 128]" - # t771 = prims.add(t767, t770) # t771: "cuda:0 f32[1, 32, 512, 128]" - t772 = prims.convert_element_type(t771, dtypes.bfloat16) # t772: "cuda:0 bf16[1, 32, 512, 128]" - t773 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t773: "cuda:0 bf16[1, 32, 512, 128]" - t774 = prims.slice_prim(t773, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t774: "cuda:0 bf16[1, 32, 512, 64]" - t775 = prims.slice_prim(t773, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t775: "cuda:0 bf16[1, 32, 512, 64]" - t776 = prims.convert_element_type(t775, dtypes.float32) # t776: "cuda:0 f32[1, 32, 512, 64]" - t777 = prims.neg(t776) # t777: "cuda:0 f32[1, 32, 512, 64]" - t778 = prims.convert_element_type(t777, dtypes.bfloat16) # t778: "cuda:0 bf16[1, 32, 512, 64]" - t780 = prims.cat((t778, t774), -1) # t780: "cuda:0 bf16[1, 32, 512, 128]" - t781 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t781: "cuda:0 f32[1, 32, 512, 128]" - t782 = prims.convert_element_type(t773, dtypes.float32) # t782: "cuda:0 f32[1, 32, 512, 128]" - t783 = ltorch.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - # t783 = prims.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - t784 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t784: "cuda:0 f32[1, 32, 512, 128]" - t785 = prims.convert_element_type(t780, dtypes.float32) # t785: "cuda:0 f32[1, 32, 512, 128]" - t786 = ltorch.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - # t786 = prims.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - t787 = ltorch.add(t783, t786, alpha=None) # t787: "cuda:0 f32[1, 32, 512, 128]" - # t787 = prims.add(t783, t786) # t787: "cuda:0 f32[1, 32, 512, 128]" - t788 = prims.convert_element_type(t787, dtypes.bfloat16) # t788: "cuda:0 bf16[1, 32, 512, 128]" - t789 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t789: "cuda:0 bf16[1, 32, 512, 0]" - t791 = prims.cat((t772, t789), -1) # t791: "cuda:0 bf16[1, 32, 512, 128]" - t792 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t792: "cuda:0 bf16[1, 32, 512, 0]" - t794 = prims.cat((t788, t792), -1) # t794: "cuda:0 bf16[1, 32, 512, 128]" - (t795, t796, t797, t798) = cudnn_sdpa_fwd(t791, t794, t756, None, 0.0, True, scale=0.08838834764831843) - t801 = prims.transpose(t795, (0, 2, 1, 3)) # t801: "cuda:0 bf16[1, 512, 32, 128]" - t805 = prims.reshape(t801, (1, 512, 4096)) # t805: "cuda:0 bf16[1, 512, 4096]" - t806 = prims.linear(t805, t_transformer_h_5_attn_proj_weight, None) # t806: "cuda:0 bf16[1, 512, 4096]" - t807 = prims.convert_element_type(t806, dtypes.float32) # t807: "cuda:0 f32[1, 512, 4096]" - t808 = prims.convert_element_type(t704, dtypes.float32) # t808: "cuda:0 f32[1, 512, 4096]" - t809 = ltorch.add(t807, t808, alpha=None) # t809: "cuda:0 f32[1, 512, 4096]" - # t809 = prims.add(t807, t808) # t809: "cuda:0 f32[1, 512, 4096]" - t810 = prims.convert_element_type(t809, dtypes.bfloat16) # t810: "cuda:0 bf16[1, 512, 4096]" - t811 = prims.convert_element_type(t810, dtypes.float32) # t811: "cuda:0 f32[1, 512, 4096]" - t812 = ltorch.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - # t812 = prims.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - t814 = prims.sum(t812, (2,)) # t814: "cuda:0 f32[1, 512]" - t815 = prims.broadcast_in_dim(t814, [1, 512, 1], [0, 1]) # t815: "cuda:0 f32[1, 512, 1]" - t817 = ltorch.true_divide(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - # t817 = prims.div(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - t819 = ltorch.add(t817, 1e-05, alpha=None) # t819: "cuda:0 f32[1, 512, 1]" - # t819 = prims.add(t817, 1e-05) # t819: "cuda:0 f32[1, 512, 1]" - t820 = prims.rsqrt(t819) # t820: "cuda:0 f32[1, 512, 1]" - t821 = prims.broadcast_in_dim(t820, (1, 512, 4096), (0, 1, 2)) # t821: "cuda:0 f32[1, 512, 4096]" - t822 = ltorch.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - # t822 = prims.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - t823 = prims.convert_element_type(t822, dtypes.bfloat16) # t823: "cuda:0 bf16[1, 512, 4096]" - t824 = prims.broadcast_in_dim(t_transformer_h_5_norm_2_weight, (1, 512, 4096), (2,)) # t824: "cuda:0 bf16[1, 512, 4096]" - t825 = prims.convert_element_type(t823, dtypes.float32) # t825: "cuda:0 f32[1, 512, 4096]" - t826 = prims.convert_element_type(t824, dtypes.float32) # t826: "cuda:0 f32[1, 512, 4096]" - t827 = ltorch.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - # t827 = prims.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - t828 = prims.convert_element_type(t827, dtypes.bfloat16) # t828: "cuda:0 bf16[1, 512, 4096]" - t829 = prims.linear(t828, t_transformer_h_5_mlp_fc_1_weight, None) # t829: "cuda:0 bf16[1, 512, 11008]" - t830 = prims.linear(t828, t_transformer_h_5_mlp_fc_2_weight, None) # t830: "cuda:0 bf16[1, 512, 11008]" - t831 = prims.convert_element_type(t829, dtypes.float32) # t831: "cuda:0 f32[1, 512, 11008]" - t832 = prims.neg(t831) # t832: "cuda:0 f32[1, 512, 11008]" - t833 = prims.exp(t832) # t833: "cuda:0 f32[1, 512, 11008]" - t834 = ltorch.add(1.0, t833, alpha=None) # t834: "cuda:0 f32[1, 512, 11008]" - # t834 = prims.add(1.0, t833) # t834: "cuda:0 f32[1, 512, 11008]" - t835 = prims.reciprocal(t834) # t835: "cuda:0 f32[1, 512, 11008]" - t836 = prims.convert_element_type(t835, dtypes.bfloat16) # t836: "cuda:0 bf16[1, 512, 11008]" - t837 = prims.convert_element_type(t829, dtypes.float32) # t837: "cuda:0 f32[1, 512, 11008]" - t838 = prims.convert_element_type(t836, dtypes.float32) # t838: "cuda:0 f32[1, 512, 11008]" - t839 = ltorch.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - # t839 = prims.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - t840 = prims.convert_element_type(t839, dtypes.bfloat16) # t840: "cuda:0 bf16[1, 512, 11008]" - t841 = prims.convert_element_type(t840, dtypes.float32) # t841: "cuda:0 f32[1, 512, 11008]" - t842 = prims.convert_element_type(t830, dtypes.float32) # t842: "cuda:0 f32[1, 512, 11008]" - t843 = ltorch.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - # t843 = prims.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - t844 = prims.convert_element_type(t843, dtypes.bfloat16) # t844: "cuda:0 bf16[1, 512, 11008]" - t845 = prims.linear(t844, t_transformer_h_5_mlp_proj_weight, None) # t845: "cuda:0 bf16[1, 512, 4096]" - t846 = prims.convert_element_type(t845, dtypes.float32) # t846: "cuda:0 f32[1, 512, 4096]" - t847 = prims.convert_element_type(t810, dtypes.float32) # t847: "cuda:0 f32[1, 512, 4096]" - t848 = ltorch.add(t846, t847, alpha=None) # t848: "cuda:0 f32[1, 512, 4096]" - # t848 = prims.add(t846, t847) # t848: "cuda:0 f32[1, 512, 4096]" - t849 = prims.convert_element_type(t848, dtypes.bfloat16) # t849: "cuda:0 bf16[1, 512, 4096]" - t850 = prims.convert_element_type(t849, dtypes.float32) # t850: "cuda:0 f32[1, 512, 4096]" - t851 = ltorch.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - # t851 = prims.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - t853 = prims.sum(t851, (2,)) # t853: "cuda:0 f32[1, 512]" - t854 = prims.broadcast_in_dim(t853, [1, 512, 1], [0, 1]) # t854: "cuda:0 f32[1, 512, 1]" - t856 = ltorch.true_divide(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - # t856 = prims.div(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - t858 = ltorch.add(t856, 1e-05, alpha=None) # t858: "cuda:0 f32[1, 512, 1]" - # t858 = prims.add(t856, 1e-05) # t858: "cuda:0 f32[1, 512, 1]" - t859 = prims.rsqrt(t858) # t859: "cuda:0 f32[1, 512, 1]" - t860 = prims.broadcast_in_dim(t859, (1, 512, 4096), (0, 1, 2)) # t860: "cuda:0 f32[1, 512, 4096]" - t861 = ltorch.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - # t861 = prims.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - t862 = prims.convert_element_type(t861, dtypes.bfloat16) # t862: "cuda:0 bf16[1, 512, 4096]" - t863 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_weight, (1, 512, 4096), (2,)) # t863: "cuda:0 bf16[1, 512, 4096]" - t864 = prims.convert_element_type(t862, dtypes.float32) # t864: "cuda:0 f32[1, 512, 4096]" - t865 = prims.convert_element_type(t863, dtypes.float32) # t865: "cuda:0 f32[1, 512, 4096]" - t866 = ltorch.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - # t866 = prims.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - t867 = prims.convert_element_type(t866, dtypes.bfloat16) # t867: "cuda:0 bf16[1, 512, 4096]" - t868 = prims.linear(t867, t_transformer_h_6_attn_attn_weight, None) # t868: "cuda:0 bf16[1, 512, 12288]" - t874 = prims.reshape(t868, (1, 512, 32, 3, 128)) # t874: "cuda:0 bf16[1, 512, 32, 3, 128]" - t880 = prims.transpose(t874, (0, 2, 3, 1, 4)) # t880: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t881, t882, t883) = ltorch.split(t880, (1, 1, 1), 2) - # t881 = prims.slice_prim(t880, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t881: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t882 = prims.slice_prim(t880, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t882: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t883 = prims.slice_prim(t880, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t883: "cuda:0 bf16[1, 32, 1, 512, 128]" - t889 = prims.reshape(t881, (1, 32, 512, 128)) # t889: "cuda:0 bf16[1, 32, 512, 128]" - t895 = prims.reshape(t882, (1, 32, 512, 128)) # t895: "cuda:0 bf16[1, 32, 512, 128]" - t901 = prims.reshape(t883, (1, 32, 512, 128)) # t901: "cuda:0 bf16[1, 32, 512, 128]" - t902 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t902: "cuda:0 bf16[1, 32, 512, 128]" - t903 = prims.slice_prim(t902, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t903: "cuda:0 bf16[1, 32, 512, 64]" - t904 = prims.slice_prim(t902, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t904: "cuda:0 bf16[1, 32, 512, 64]" - t905 = prims.convert_element_type(t904, dtypes.float32) # t905: "cuda:0 f32[1, 32, 512, 64]" - t906 = prims.neg(t905) # t906: "cuda:0 f32[1, 32, 512, 64]" - t907 = prims.convert_element_type(t906, dtypes.bfloat16) # t907: "cuda:0 bf16[1, 32, 512, 64]" - t909 = prims.cat((t907, t903), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - t910 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t910: "cuda:0 f32[1, 32, 512, 128]" - t911 = prims.convert_element_type(t902, dtypes.float32) # t911: "cuda:0 f32[1, 32, 512, 128]" - t912 = ltorch.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - # t912 = prims.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - t913 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t913: "cuda:0 f32[1, 32, 512, 128]" - t914 = prims.convert_element_type(t909, dtypes.float32) # t914: "cuda:0 f32[1, 32, 512, 128]" - t915 = ltorch.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - # t915 = prims.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - t916 = ltorch.add(t912, t915, alpha=None) # t916: "cuda:0 f32[1, 32, 512, 128]" - # t916 = prims.add(t912, t915) # t916: "cuda:0 f32[1, 32, 512, 128]" - t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: "cuda:0 bf16[1, 32, 512, 128]" - t918 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: "cuda:0 bf16[1, 32, 512, 128]" - t919 = prims.slice_prim(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: "cuda:0 bf16[1, 32, 512, 64]" - t920 = prims.slice_prim(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: "cuda:0 bf16[1, 32, 512, 64]" - t921 = prims.convert_element_type(t920, dtypes.float32) # t921: "cuda:0 f32[1, 32, 512, 64]" - t922 = prims.neg(t921) # t922: "cuda:0 f32[1, 32, 512, 64]" - t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: "cuda:0 bf16[1, 32, 512, 64]" - t925 = prims.cat((t923, t919), -1) # t925: "cuda:0 bf16[1, 32, 512, 128]" - t926 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t926: "cuda:0 f32[1, 32, 512, 128]" - t927 = prims.convert_element_type(t918, dtypes.float32) # t927: "cuda:0 f32[1, 32, 512, 128]" - t928 = ltorch.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - # t928 = prims.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - t929 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t929: "cuda:0 f32[1, 32, 512, 128]" - t930 = prims.convert_element_type(t925, dtypes.float32) # t930: "cuda:0 f32[1, 32, 512, 128]" - t931 = ltorch.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - # t931 = prims.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - t932 = ltorch.add(t928, t931, alpha=None) # t932: "cuda:0 f32[1, 32, 512, 128]" - # t932 = prims.add(t928, t931) # t932: "cuda:0 f32[1, 32, 512, 128]" - t933 = prims.convert_element_type(t932, dtypes.bfloat16) # t933: "cuda:0 bf16[1, 32, 512, 128]" - t934 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t934: "cuda:0 bf16[1, 32, 512, 0]" - t936 = prims.cat((t917, t934), -1) # t936: "cuda:0 bf16[1, 32, 512, 128]" - t937 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t937: "cuda:0 bf16[1, 32, 512, 0]" - t939 = prims.cat((t933, t937), -1) # t939: "cuda:0 bf16[1, 32, 512, 128]" - (t940, t941, t942, t943) = cudnn_sdpa_fwd(t936, t939, t901, None, 0.0, True, scale=0.08838834764831843) - t946 = prims.transpose(t940, (0, 2, 1, 3)) # t946: "cuda:0 bf16[1, 512, 32, 128]" - t950 = prims.reshape(t946, (1, 512, 4096)) # t950: "cuda:0 bf16[1, 512, 4096]" - t951 = prims.linear(t950, t_transformer_h_6_attn_proj_weight, None) # t951: "cuda:0 bf16[1, 512, 4096]" - t952 = prims.convert_element_type(t951, dtypes.float32) # t952: "cuda:0 f32[1, 512, 4096]" - t953 = prims.convert_element_type(t849, dtypes.float32) # t953: "cuda:0 f32[1, 512, 4096]" - t954 = ltorch.add(t952, t953, alpha=None) # t954: "cuda:0 f32[1, 512, 4096]" - # t954 = prims.add(t952, t953) # t954: "cuda:0 f32[1, 512, 4096]" - t955 = prims.convert_element_type(t954, dtypes.bfloat16) # t955: "cuda:0 bf16[1, 512, 4096]" - t956 = prims.convert_element_type(t955, dtypes.float32) # t956: "cuda:0 f32[1, 512, 4096]" - t957 = ltorch.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - # t957 = prims.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - t959 = prims.sum(t957, (2,)) # t959: "cuda:0 f32[1, 512]" - t960 = prims.broadcast_in_dim(t959, [1, 512, 1], [0, 1]) # t960: "cuda:0 f32[1, 512, 1]" - t962 = ltorch.true_divide(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - # t962 = prims.div(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - t964 = ltorch.add(t962, 1e-05, alpha=None) # t964: "cuda:0 f32[1, 512, 1]" - # t964 = prims.add(t962, 1e-05) # t964: "cuda:0 f32[1, 512, 1]" - t965 = prims.rsqrt(t964) # t965: "cuda:0 f32[1, 512, 1]" - t966 = prims.broadcast_in_dim(t965, (1, 512, 4096), (0, 1, 2)) # t966: "cuda:0 f32[1, 512, 4096]" - t967 = ltorch.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - # t967 = prims.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - t968 = prims.convert_element_type(t967, dtypes.bfloat16) # t968: "cuda:0 bf16[1, 512, 4096]" - t969 = prims.broadcast_in_dim(t_transformer_h_6_norm_2_weight, (1, 512, 4096), (2,)) # t969: "cuda:0 bf16[1, 512, 4096]" - t970 = prims.convert_element_type(t968, dtypes.float32) # t970: "cuda:0 f32[1, 512, 4096]" - t971 = prims.convert_element_type(t969, dtypes.float32) # t971: "cuda:0 f32[1, 512, 4096]" - t972 = ltorch.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - # t972 = prims.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - t973 = prims.convert_element_type(t972, dtypes.bfloat16) # t973: "cuda:0 bf16[1, 512, 4096]" - t974 = prims.linear(t973, t_transformer_h_6_mlp_fc_1_weight, None) # t974: "cuda:0 bf16[1, 512, 11008]" - t975 = prims.linear(t973, t_transformer_h_6_mlp_fc_2_weight, None) # t975: "cuda:0 bf16[1, 512, 11008]" - t976 = prims.convert_element_type(t974, dtypes.float32) # t976: "cuda:0 f32[1, 512, 11008]" - t977 = prims.neg(t976) # t977: "cuda:0 f32[1, 512, 11008]" - t978 = prims.exp(t977) # t978: "cuda:0 f32[1, 512, 11008]" - t979 = ltorch.add(1.0, t978, alpha=None) # t979: "cuda:0 f32[1, 512, 11008]" - # t979 = prims.add(1.0, t978) # t979: "cuda:0 f32[1, 512, 11008]" - t980 = prims.reciprocal(t979) # t980: "cuda:0 f32[1, 512, 11008]" - t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: "cuda:0 bf16[1, 512, 11008]" - t982 = prims.convert_element_type(t974, dtypes.float32) # t982: "cuda:0 f32[1, 512, 11008]" - t983 = prims.convert_element_type(t981, dtypes.float32) # t983: "cuda:0 f32[1, 512, 11008]" - t984 = ltorch.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - # t984 = prims.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - t985 = prims.convert_element_type(t984, dtypes.bfloat16) # t985: "cuda:0 bf16[1, 512, 11008]" - t986 = prims.convert_element_type(t985, dtypes.float32) # t986: "cuda:0 f32[1, 512, 11008]" - t987 = prims.convert_element_type(t975, dtypes.float32) # t987: "cuda:0 f32[1, 512, 11008]" - t988 = ltorch.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - # t988 = prims.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - t989 = prims.convert_element_type(t988, dtypes.bfloat16) # t989: "cuda:0 bf16[1, 512, 11008]" - t990 = prims.linear(t989, t_transformer_h_6_mlp_proj_weight, None) # t990: "cuda:0 bf16[1, 512, 4096]" - t991 = prims.convert_element_type(t990, dtypes.float32) # t991: "cuda:0 f32[1, 512, 4096]" - t992 = prims.convert_element_type(t955, dtypes.float32) # t992: "cuda:0 f32[1, 512, 4096]" - t993 = ltorch.add(t991, t992, alpha=None) # t993: "cuda:0 f32[1, 512, 4096]" - # t993 = prims.add(t991, t992) # t993: "cuda:0 f32[1, 512, 4096]" - t994 = prims.convert_element_type(t993, dtypes.bfloat16) # t994: "cuda:0 bf16[1, 512, 4096]" - t995 = prims.convert_element_type(t994, dtypes.float32) # t995: "cuda:0 f32[1, 512, 4096]" - t996 = ltorch.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - # t996 = prims.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - t998 = prims.sum(t996, (2,)) # t998: "cuda:0 f32[1, 512]" - t999 = prims.broadcast_in_dim(t998, [1, 512, 1], [0, 1]) # t999: "cuda:0 f32[1, 512, 1]" - t1001 = ltorch.true_divide(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - # t1001 = prims.div(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - t1003 = ltorch.add(t1001, 1e-05, alpha=None) # t1003: "cuda:0 f32[1, 512, 1]" - # t1003 = prims.add(t1001, 1e-05) # t1003: "cuda:0 f32[1, 512, 1]" - t1004 = prims.rsqrt(t1003) # t1004: "cuda:0 f32[1, 512, 1]" - t1005 = prims.broadcast_in_dim(t1004, (1, 512, 4096), (0, 1, 2)) # t1005: "cuda:0 f32[1, 512, 4096]" - t1006 = ltorch.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - # t1006 = prims.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - t1007 = prims.convert_element_type(t1006, dtypes.bfloat16) # t1007: "cuda:0 bf16[1, 512, 4096]" - t1008 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_weight, (1, 512, 4096), (2,)) # t1008: "cuda:0 bf16[1, 512, 4096]" - t1009 = prims.convert_element_type(t1007, dtypes.float32) # t1009: "cuda:0 f32[1, 512, 4096]" - t1010 = prims.convert_element_type(t1008, dtypes.float32) # t1010: "cuda:0 f32[1, 512, 4096]" - t1011 = ltorch.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - # t1011 = prims.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - t1012 = prims.convert_element_type(t1011, dtypes.bfloat16) # t1012: "cuda:0 bf16[1, 512, 4096]" - t1013 = prims.linear(t1012, t_transformer_h_7_attn_attn_weight, None) # t1013: "cuda:0 bf16[1, 512, 12288]" - t1019 = prims.reshape(t1013, (1, 512, 32, 3, 128)) # t1019: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1025 = prims.transpose(t1019, (0, 2, 3, 1, 4)) # t1025: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1026, t1027, t1028) = ltorch.split(t1025, (1, 1, 1), 2) - # t1026 = prims.slice_prim(t1025, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1026: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1027 = prims.slice_prim(t1025, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1027: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1028 = prims.slice_prim(t1025, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1028: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1034 = prims.reshape(t1026, (1, 32, 512, 128)) # t1034: "cuda:0 bf16[1, 32, 512, 128]" - t1040 = prims.reshape(t1027, (1, 32, 512, 128)) # t1040: "cuda:0 bf16[1, 32, 512, 128]" - t1046 = prims.reshape(t1028, (1, 32, 512, 128)) # t1046: "cuda:0 bf16[1, 32, 512, 128]" - t1047 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1047: "cuda:0 bf16[1, 32, 512, 128]" - t1048 = prims.slice_prim(t1047, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1048: "cuda:0 bf16[1, 32, 512, 64]" - t1049 = prims.slice_prim(t1047, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1049: "cuda:0 bf16[1, 32, 512, 64]" - t1050 = prims.convert_element_type(t1049, dtypes.float32) # t1050: "cuda:0 f32[1, 32, 512, 64]" - t1051 = prims.neg(t1050) # t1051: "cuda:0 f32[1, 32, 512, 64]" - t1052 = prims.convert_element_type(t1051, dtypes.bfloat16) # t1052: "cuda:0 bf16[1, 32, 512, 64]" - t1054 = prims.cat((t1052, t1048), -1) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - t1055 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1055: "cuda:0 f32[1, 32, 512, 128]" - t1056 = prims.convert_element_type(t1047, dtypes.float32) # t1056: "cuda:0 f32[1, 32, 512, 128]" - t1057 = ltorch.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - # t1057 = prims.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - t1058 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1058: "cuda:0 f32[1, 32, 512, 128]" - t1059 = prims.convert_element_type(t1054, dtypes.float32) # t1059: "cuda:0 f32[1, 32, 512, 128]" - t1060 = ltorch.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - # t1060 = prims.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - t1061 = ltorch.add(t1057, t1060, alpha=None) # t1061: "cuda:0 f32[1, 32, 512, 128]" - # t1061 = prims.add(t1057, t1060) # t1061: "cuda:0 f32[1, 32, 512, 128]" - t1062 = prims.convert_element_type(t1061, dtypes.bfloat16) # t1062: "cuda:0 bf16[1, 32, 512, 128]" - t1063 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1063: "cuda:0 bf16[1, 32, 512, 128]" - t1064 = prims.slice_prim(t1063, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1064: "cuda:0 bf16[1, 32, 512, 64]" - t1065 = prims.slice_prim(t1063, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1065: "cuda:0 bf16[1, 32, 512, 64]" - t1066 = prims.convert_element_type(t1065, dtypes.float32) # t1066: "cuda:0 f32[1, 32, 512, 64]" - t1067 = prims.neg(t1066) # t1067: "cuda:0 f32[1, 32, 512, 64]" - t1068 = prims.convert_element_type(t1067, dtypes.bfloat16) # t1068: "cuda:0 bf16[1, 32, 512, 64]" - t1070 = prims.cat((t1068, t1064), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - t1071 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1071: "cuda:0 f32[1, 32, 512, 128]" - t1072 = prims.convert_element_type(t1063, dtypes.float32) # t1072: "cuda:0 f32[1, 32, 512, 128]" - t1073 = ltorch.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - # t1073 = prims.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - t1074 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1074: "cuda:0 f32[1, 32, 512, 128]" - t1075 = prims.convert_element_type(t1070, dtypes.float32) # t1075: "cuda:0 f32[1, 32, 512, 128]" - t1076 = ltorch.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - # t1076 = prims.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - t1077 = ltorch.add(t1073, t1076, alpha=None) # t1077: "cuda:0 f32[1, 32, 512, 128]" - # t1077 = prims.add(t1073, t1076) # t1077: "cuda:0 f32[1, 32, 512, 128]" - t1078 = prims.convert_element_type(t1077, dtypes.bfloat16) # t1078: "cuda:0 bf16[1, 32, 512, 128]" - t1079 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1079: "cuda:0 bf16[1, 32, 512, 0]" - t1081 = prims.cat((t1062, t1079), -1) # t1081: "cuda:0 bf16[1, 32, 512, 128]" - t1082 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1082: "cuda:0 bf16[1, 32, 512, 0]" - t1084 = prims.cat((t1078, t1082), -1) # t1084: "cuda:0 bf16[1, 32, 512, 128]" - (t1085, t1086, t1087, t1088) = cudnn_sdpa_fwd(t1081, t1084, t1046, None, 0.0, True, scale=0.08838834764831843) - t1091 = prims.transpose(t1085, (0, 2, 1, 3)) # t1091: "cuda:0 bf16[1, 512, 32, 128]" - t1095 = prims.reshape(t1091, (1, 512, 4096)) # t1095: "cuda:0 bf16[1, 512, 4096]" - t1096 = prims.linear(t1095, t_transformer_h_7_attn_proj_weight, None) # t1096: "cuda:0 bf16[1, 512, 4096]" - t1097 = prims.convert_element_type(t1096, dtypes.float32) # t1097: "cuda:0 f32[1, 512, 4096]" - t1098 = prims.convert_element_type(t994, dtypes.float32) # t1098: "cuda:0 f32[1, 512, 4096]" - t1099 = ltorch.add(t1097, t1098, alpha=None) # t1099: "cuda:0 f32[1, 512, 4096]" - # t1099 = prims.add(t1097, t1098) # t1099: "cuda:0 f32[1, 512, 4096]" - t1100 = prims.convert_element_type(t1099, dtypes.bfloat16) # t1100: "cuda:0 bf16[1, 512, 4096]" - t1101 = prims.convert_element_type(t1100, dtypes.float32) # t1101: "cuda:0 f32[1, 512, 4096]" - t1102 = ltorch.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - # t1102 = prims.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - t1104 = prims.sum(t1102, (2,)) # t1104: "cuda:0 f32[1, 512]" - t1105 = prims.broadcast_in_dim(t1104, [1, 512, 1], [0, 1]) # t1105: "cuda:0 f32[1, 512, 1]" - t1107 = ltorch.true_divide(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - # t1107 = prims.div(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - t1109 = ltorch.add(t1107, 1e-05, alpha=None) # t1109: "cuda:0 f32[1, 512, 1]" - # t1109 = prims.add(t1107, 1e-05) # t1109: "cuda:0 f32[1, 512, 1]" - t1110 = prims.rsqrt(t1109) # t1110: "cuda:0 f32[1, 512, 1]" - t1111 = prims.broadcast_in_dim(t1110, (1, 512, 4096), (0, 1, 2)) # t1111: "cuda:0 f32[1, 512, 4096]" - t1112 = ltorch.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - # t1112 = prims.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - t1113 = prims.convert_element_type(t1112, dtypes.bfloat16) # t1113: "cuda:0 bf16[1, 512, 4096]" - t1114 = prims.broadcast_in_dim(t_transformer_h_7_norm_2_weight, (1, 512, 4096), (2,)) # t1114: "cuda:0 bf16[1, 512, 4096]" - t1115 = prims.convert_element_type(t1113, dtypes.float32) # t1115: "cuda:0 f32[1, 512, 4096]" - t1116 = prims.convert_element_type(t1114, dtypes.float32) # t1116: "cuda:0 f32[1, 512, 4096]" - t1117 = ltorch.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - # t1117 = prims.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - t1118 = prims.convert_element_type(t1117, dtypes.bfloat16) # t1118: "cuda:0 bf16[1, 512, 4096]" - t1119 = prims.linear(t1118, t_transformer_h_7_mlp_fc_1_weight, None) # t1119: "cuda:0 bf16[1, 512, 11008]" - t1120 = prims.linear(t1118, t_transformer_h_7_mlp_fc_2_weight, None) # t1120: "cuda:0 bf16[1, 512, 11008]" - t1121 = prims.convert_element_type(t1119, dtypes.float32) # t1121: "cuda:0 f32[1, 512, 11008]" - t1122 = prims.neg(t1121) # t1122: "cuda:0 f32[1, 512, 11008]" - t1123 = prims.exp(t1122) # t1123: "cuda:0 f32[1, 512, 11008]" - t1124 = ltorch.add(1.0, t1123, alpha=None) # t1124: "cuda:0 f32[1, 512, 11008]" - # t1124 = prims.add(1.0, t1123) # t1124: "cuda:0 f32[1, 512, 11008]" - t1125 = prims.reciprocal(t1124) # t1125: "cuda:0 f32[1, 512, 11008]" - t1126 = prims.convert_element_type(t1125, dtypes.bfloat16) # t1126: "cuda:0 bf16[1, 512, 11008]" - t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: "cuda:0 f32[1, 512, 11008]" - t1128 = prims.convert_element_type(t1126, dtypes.float32) # t1128: "cuda:0 f32[1, 512, 11008]" - t1129 = ltorch.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - # t1129 = prims.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - t1130 = prims.convert_element_type(t1129, dtypes.bfloat16) # t1130: "cuda:0 bf16[1, 512, 11008]" - t1131 = prims.convert_element_type(t1130, dtypes.float32) # t1131: "cuda:0 f32[1, 512, 11008]" - t1132 = prims.convert_element_type(t1120, dtypes.float32) # t1132: "cuda:0 f32[1, 512, 11008]" - t1133 = ltorch.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - # t1133 = prims.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - t1134 = prims.convert_element_type(t1133, dtypes.bfloat16) # t1134: "cuda:0 bf16[1, 512, 11008]" - t1135 = prims.linear(t1134, t_transformer_h_7_mlp_proj_weight, None) # t1135: "cuda:0 bf16[1, 512, 4096]" - t1136 = prims.convert_element_type(t1135, dtypes.float32) # t1136: "cuda:0 f32[1, 512, 4096]" - t1137 = prims.convert_element_type(t1100, dtypes.float32) # t1137: "cuda:0 f32[1, 512, 4096]" - t1138 = ltorch.add(t1136, t1137, alpha=None) # t1138: "cuda:0 f32[1, 512, 4096]" - # t1138 = prims.add(t1136, t1137) # t1138: "cuda:0 f32[1, 512, 4096]" - t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: "cuda:0 bf16[1, 512, 4096]" - t1140 = prims.convert_element_type(t1139, dtypes.float32) # t1140: "cuda:0 f32[1, 512, 4096]" - t1141 = ltorch.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - # t1141 = prims.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - t1143 = prims.sum(t1141, (2,)) # t1143: "cuda:0 f32[1, 512]" - t1144 = prims.broadcast_in_dim(t1143, [1, 512, 1], [0, 1]) # t1144: "cuda:0 f32[1, 512, 1]" - t1146 = ltorch.true_divide(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - # t1146 = prims.div(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - t1148 = ltorch.add(t1146, 1e-05, alpha=None) # t1148: "cuda:0 f32[1, 512, 1]" - # t1148 = prims.add(t1146, 1e-05) # t1148: "cuda:0 f32[1, 512, 1]" - t1149 = prims.rsqrt(t1148) # t1149: "cuda:0 f32[1, 512, 1]" - t1150 = prims.broadcast_in_dim(t1149, (1, 512, 4096), (0, 1, 2)) # t1150: "cuda:0 f32[1, 512, 4096]" - t1151 = ltorch.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - # t1151 = prims.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - t1152 = prims.convert_element_type(t1151, dtypes.bfloat16) # t1152: "cuda:0 bf16[1, 512, 4096]" - t1153 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_weight, (1, 512, 4096), (2,)) # t1153: "cuda:0 bf16[1, 512, 4096]" - t1154 = prims.convert_element_type(t1152, dtypes.float32) # t1154: "cuda:0 f32[1, 512, 4096]" - t1155 = prims.convert_element_type(t1153, dtypes.float32) # t1155: "cuda:0 f32[1, 512, 4096]" - t1156 = ltorch.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - # t1156 = prims.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - t1157 = prims.convert_element_type(t1156, dtypes.bfloat16) # t1157: "cuda:0 bf16[1, 512, 4096]" - t1158 = prims.linear(t1157, t_transformer_h_8_attn_attn_weight, None) # t1158: "cuda:0 bf16[1, 512, 12288]" - t1164 = prims.reshape(t1158, (1, 512, 32, 3, 128)) # t1164: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1170 = prims.transpose(t1164, (0, 2, 3, 1, 4)) # t1170: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1171, t1172, t1173) = ltorch.split(t1170, (1, 1, 1), 2) - # t1171 = prims.slice_prim(t1170, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1171: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1172 = prims.slice_prim(t1170, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1172: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1173 = prims.slice_prim(t1170, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1173: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1179 = prims.reshape(t1171, (1, 32, 512, 128)) # t1179: "cuda:0 bf16[1, 32, 512, 128]" - t1185 = prims.reshape(t1172, (1, 32, 512, 128)) # t1185: "cuda:0 bf16[1, 32, 512, 128]" - t1191 = prims.reshape(t1173, (1, 32, 512, 128)) # t1191: "cuda:0 bf16[1, 32, 512, 128]" - t1192 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1192: "cuda:0 bf16[1, 32, 512, 128]" - t1193 = prims.slice_prim(t1192, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1193: "cuda:0 bf16[1, 32, 512, 64]" - t1194 = prims.slice_prim(t1192, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1194: "cuda:0 bf16[1, 32, 512, 64]" - t1195 = prims.convert_element_type(t1194, dtypes.float32) # t1195: "cuda:0 f32[1, 32, 512, 64]" - t1196 = prims.neg(t1195) # t1196: "cuda:0 f32[1, 32, 512, 64]" - t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: "cuda:0 bf16[1, 32, 512, 64]" - t1199 = prims.cat((t1197, t1193), -1) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - t1200 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1200: "cuda:0 f32[1, 32, 512, 128]" - t1201 = prims.convert_element_type(t1192, dtypes.float32) # t1201: "cuda:0 f32[1, 32, 512, 128]" - t1202 = ltorch.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - # t1202 = prims.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - t1203 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1203: "cuda:0 f32[1, 32, 512, 128]" - t1204 = prims.convert_element_type(t1199, dtypes.float32) # t1204: "cuda:0 f32[1, 32, 512, 128]" - t1205 = ltorch.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - # t1205 = prims.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - t1206 = ltorch.add(t1202, t1205, alpha=None) # t1206: "cuda:0 f32[1, 32, 512, 128]" - # t1206 = prims.add(t1202, t1205) # t1206: "cuda:0 f32[1, 32, 512, 128]" - t1207 = prims.convert_element_type(t1206, dtypes.bfloat16) # t1207: "cuda:0 bf16[1, 32, 512, 128]" - t1208 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - t1209 = prims.slice_prim(t1208, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1209: "cuda:0 bf16[1, 32, 512, 64]" - t1210 = prims.slice_prim(t1208, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1210: "cuda:0 bf16[1, 32, 512, 64]" - t1211 = prims.convert_element_type(t1210, dtypes.float32) # t1211: "cuda:0 f32[1, 32, 512, 64]" - t1212 = prims.neg(t1211) # t1212: "cuda:0 f32[1, 32, 512, 64]" - t1213 = prims.convert_element_type(t1212, dtypes.bfloat16) # t1213: "cuda:0 bf16[1, 32, 512, 64]" - t1215 = prims.cat((t1213, t1209), -1) # t1215: "cuda:0 bf16[1, 32, 512, 128]" - t1216 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1216: "cuda:0 f32[1, 32, 512, 128]" - t1217 = prims.convert_element_type(t1208, dtypes.float32) # t1217: "cuda:0 f32[1, 32, 512, 128]" - t1218 = ltorch.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - # t1218 = prims.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - t1219 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1219: "cuda:0 f32[1, 32, 512, 128]" - t1220 = prims.convert_element_type(t1215, dtypes.float32) # t1220: "cuda:0 f32[1, 32, 512, 128]" - t1221 = ltorch.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - # t1221 = prims.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - t1222 = ltorch.add(t1218, t1221, alpha=None) # t1222: "cuda:0 f32[1, 32, 512, 128]" - # t1222 = prims.add(t1218, t1221) # t1222: "cuda:0 f32[1, 32, 512, 128]" - t1223 = prims.convert_element_type(t1222, dtypes.bfloat16) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - t1224 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1224: "cuda:0 bf16[1, 32, 512, 0]" - t1226 = prims.cat((t1207, t1224), -1) # t1226: "cuda:0 bf16[1, 32, 512, 128]" - t1227 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1227: "cuda:0 bf16[1, 32, 512, 0]" - t1229 = prims.cat((t1223, t1227), -1) # t1229: "cuda:0 bf16[1, 32, 512, 128]" - (t1230, t1231, t1232, t1233) = cudnn_sdpa_fwd(t1226, t1229, t1191, None, 0.0, True, scale=0.08838834764831843) - t1236 = prims.transpose(t1230, (0, 2, 1, 3)) # t1236: "cuda:0 bf16[1, 512, 32, 128]" - t1240 = prims.reshape(t1236, (1, 512, 4096)) # t1240: "cuda:0 bf16[1, 512, 4096]" - t1241 = prims.linear(t1240, t_transformer_h_8_attn_proj_weight, None) # t1241: "cuda:0 bf16[1, 512, 4096]" - t1242 = prims.convert_element_type(t1241, dtypes.float32) # t1242: "cuda:0 f32[1, 512, 4096]" - t1243 = prims.convert_element_type(t1139, dtypes.float32) # t1243: "cuda:0 f32[1, 512, 4096]" - t1244 = ltorch.add(t1242, t1243, alpha=None) # t1244: "cuda:0 f32[1, 512, 4096]" - # t1244 = prims.add(t1242, t1243) # t1244: "cuda:0 f32[1, 512, 4096]" - t1245 = prims.convert_element_type(t1244, dtypes.bfloat16) # t1245: "cuda:0 bf16[1, 512, 4096]" - t1246 = prims.convert_element_type(t1245, dtypes.float32) # t1246: "cuda:0 f32[1, 512, 4096]" - t1247 = ltorch.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - # t1247 = prims.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - t1249 = prims.sum(t1247, (2,)) # t1249: "cuda:0 f32[1, 512]" - t1250 = prims.broadcast_in_dim(t1249, [1, 512, 1], [0, 1]) # t1250: "cuda:0 f32[1, 512, 1]" - t1252 = ltorch.true_divide(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - # t1252 = prims.div(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - t1254 = ltorch.add(t1252, 1e-05, alpha=None) # t1254: "cuda:0 f32[1, 512, 1]" - # t1254 = prims.add(t1252, 1e-05) # t1254: "cuda:0 f32[1, 512, 1]" - t1255 = prims.rsqrt(t1254) # t1255: "cuda:0 f32[1, 512, 1]" - t1256 = prims.broadcast_in_dim(t1255, (1, 512, 4096), (0, 1, 2)) # t1256: "cuda:0 f32[1, 512, 4096]" - t1257 = ltorch.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - # t1257 = prims.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - t1258 = prims.convert_element_type(t1257, dtypes.bfloat16) # t1258: "cuda:0 bf16[1, 512, 4096]" - t1259 = prims.broadcast_in_dim(t_transformer_h_8_norm_2_weight, (1, 512, 4096), (2,)) # t1259: "cuda:0 bf16[1, 512, 4096]" - t1260 = prims.convert_element_type(t1258, dtypes.float32) # t1260: "cuda:0 f32[1, 512, 4096]" - t1261 = prims.convert_element_type(t1259, dtypes.float32) # t1261: "cuda:0 f32[1, 512, 4096]" - t1262 = ltorch.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - # t1262 = prims.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - t1263 = prims.convert_element_type(t1262, dtypes.bfloat16) # t1263: "cuda:0 bf16[1, 512, 4096]" - t1264 = prims.linear(t1263, t_transformer_h_8_mlp_fc_1_weight, None) # t1264: "cuda:0 bf16[1, 512, 11008]" - t1265 = prims.linear(t1263, t_transformer_h_8_mlp_fc_2_weight, None) # t1265: "cuda:0 bf16[1, 512, 11008]" - t1266 = prims.convert_element_type(t1264, dtypes.float32) # t1266: "cuda:0 f32[1, 512, 11008]" - t1267 = prims.neg(t1266) # t1267: "cuda:0 f32[1, 512, 11008]" - t1268 = prims.exp(t1267) # t1268: "cuda:0 f32[1, 512, 11008]" - t1269 = ltorch.add(1.0, t1268, alpha=None) # t1269: "cuda:0 f32[1, 512, 11008]" - # t1269 = prims.add(1.0, t1268) # t1269: "cuda:0 f32[1, 512, 11008]" - t1270 = prims.reciprocal(t1269) # t1270: "cuda:0 f32[1, 512, 11008]" - t1271 = prims.convert_element_type(t1270, dtypes.bfloat16) # t1271: "cuda:0 bf16[1, 512, 11008]" - t1272 = prims.convert_element_type(t1264, dtypes.float32) # t1272: "cuda:0 f32[1, 512, 11008]" - t1273 = prims.convert_element_type(t1271, dtypes.float32) # t1273: "cuda:0 f32[1, 512, 11008]" - t1274 = ltorch.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - # t1274 = prims.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - t1275 = prims.convert_element_type(t1274, dtypes.bfloat16) # t1275: "cuda:0 bf16[1, 512, 11008]" - t1276 = prims.convert_element_type(t1275, dtypes.float32) # t1276: "cuda:0 f32[1, 512, 11008]" - t1277 = prims.convert_element_type(t1265, dtypes.float32) # t1277: "cuda:0 f32[1, 512, 11008]" - t1278 = ltorch.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - # t1278 = prims.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - t1279 = prims.convert_element_type(t1278, dtypes.bfloat16) # t1279: "cuda:0 bf16[1, 512, 11008]" - t1280 = prims.linear(t1279, t_transformer_h_8_mlp_proj_weight, None) # t1280: "cuda:0 bf16[1, 512, 4096]" - t1281 = prims.convert_element_type(t1280, dtypes.float32) # t1281: "cuda:0 f32[1, 512, 4096]" - t1282 = prims.convert_element_type(t1245, dtypes.float32) # t1282: "cuda:0 f32[1, 512, 4096]" - t1283 = ltorch.add(t1281, t1282, alpha=None) # t1283: "cuda:0 f32[1, 512, 4096]" - # t1283 = prims.add(t1281, t1282) # t1283: "cuda:0 f32[1, 512, 4096]" - t1284 = prims.convert_element_type(t1283, dtypes.bfloat16) # t1284: "cuda:0 bf16[1, 512, 4096]" - t1285 = prims.convert_element_type(t1284, dtypes.float32) # t1285: "cuda:0 f32[1, 512, 4096]" - t1286 = ltorch.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - # t1286 = prims.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - t1288 = prims.sum(t1286, (2,)) # t1288: "cuda:0 f32[1, 512]" - t1289 = prims.broadcast_in_dim(t1288, [1, 512, 1], [0, 1]) # t1289: "cuda:0 f32[1, 512, 1]" - t1291 = ltorch.true_divide(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - # t1291 = prims.div(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - t1293 = ltorch.add(t1291, 1e-05, alpha=None) # t1293: "cuda:0 f32[1, 512, 1]" - # t1293 = prims.add(t1291, 1e-05) # t1293: "cuda:0 f32[1, 512, 1]" - t1294 = prims.rsqrt(t1293) # t1294: "cuda:0 f32[1, 512, 1]" - t1295 = prims.broadcast_in_dim(t1294, (1, 512, 4096), (0, 1, 2)) # t1295: "cuda:0 f32[1, 512, 4096]" - t1296 = ltorch.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - # t1296 = prims.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - t1297 = prims.convert_element_type(t1296, dtypes.bfloat16) # t1297: "cuda:0 bf16[1, 512, 4096]" - t1298 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_weight, (1, 512, 4096), (2,)) # t1298: "cuda:0 bf16[1, 512, 4096]" - t1299 = prims.convert_element_type(t1297, dtypes.float32) # t1299: "cuda:0 f32[1, 512, 4096]" - t1300 = prims.convert_element_type(t1298, dtypes.float32) # t1300: "cuda:0 f32[1, 512, 4096]" - t1301 = ltorch.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - # t1301 = prims.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - t1302 = prims.convert_element_type(t1301, dtypes.bfloat16) # t1302: "cuda:0 bf16[1, 512, 4096]" - t1303 = prims.linear(t1302, t_transformer_h_9_attn_attn_weight, None) # t1303: "cuda:0 bf16[1, 512, 12288]" - t1309 = prims.reshape(t1303, (1, 512, 32, 3, 128)) # t1309: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1315 = prims.transpose(t1309, (0, 2, 3, 1, 4)) # t1315: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1316, t1317, t1318) = ltorch.split(t1315, (1, 1, 1), 2) - # t1316 = prims.slice_prim(t1315, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1316: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1317 = prims.slice_prim(t1315, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1317: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1318 = prims.slice_prim(t1315, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1318: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1324 = prims.reshape(t1316, (1, 32, 512, 128)) # t1324: "cuda:0 bf16[1, 32, 512, 128]" - t1330 = prims.reshape(t1317, (1, 32, 512, 128)) # t1330: "cuda:0 bf16[1, 32, 512, 128]" - t1336 = prims.reshape(t1318, (1, 32, 512, 128)) # t1336: "cuda:0 bf16[1, 32, 512, 128]" - t1337 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: "cuda:0 bf16[1, 32, 512, 128]" - t1338 = prims.slice_prim(t1337, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1338: "cuda:0 bf16[1, 32, 512, 64]" - t1339 = prims.slice_prim(t1337, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1339: "cuda:0 bf16[1, 32, 512, 64]" - t1340 = prims.convert_element_type(t1339, dtypes.float32) # t1340: "cuda:0 f32[1, 32, 512, 64]" - t1341 = prims.neg(t1340) # t1341: "cuda:0 f32[1, 32, 512, 64]" - t1342 = prims.convert_element_type(t1341, dtypes.bfloat16) # t1342: "cuda:0 bf16[1, 32, 512, 64]" - t1344 = prims.cat((t1342, t1338), -1) # t1344: "cuda:0 bf16[1, 32, 512, 128]" - t1345 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1345: "cuda:0 f32[1, 32, 512, 128]" - t1346 = prims.convert_element_type(t1337, dtypes.float32) # t1346: "cuda:0 f32[1, 32, 512, 128]" - t1347 = ltorch.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - # t1347 = prims.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - t1348 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1348: "cuda:0 f32[1, 32, 512, 128]" - t1349 = prims.convert_element_type(t1344, dtypes.float32) # t1349: "cuda:0 f32[1, 32, 512, 128]" - t1350 = ltorch.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - # t1350 = prims.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - t1351 = ltorch.add(t1347, t1350, alpha=None) # t1351: "cuda:0 f32[1, 32, 512, 128]" - # t1351 = prims.add(t1347, t1350) # t1351: "cuda:0 f32[1, 32, 512, 128]" - t1352 = prims.convert_element_type(t1351, dtypes.bfloat16) # t1352: "cuda:0 bf16[1, 32, 512, 128]" - t1353 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1353: "cuda:0 bf16[1, 32, 512, 128]" - t1354 = prims.slice_prim(t1353, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1354: "cuda:0 bf16[1, 32, 512, 64]" - t1355 = prims.slice_prim(t1353, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1355: "cuda:0 bf16[1, 32, 512, 64]" - t1356 = prims.convert_element_type(t1355, dtypes.float32) # t1356: "cuda:0 f32[1, 32, 512, 64]" - t1357 = prims.neg(t1356) # t1357: "cuda:0 f32[1, 32, 512, 64]" - t1358 = prims.convert_element_type(t1357, dtypes.bfloat16) # t1358: "cuda:0 bf16[1, 32, 512, 64]" - t1360 = prims.cat((t1358, t1354), -1) # t1360: "cuda:0 bf16[1, 32, 512, 128]" - t1361 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1361: "cuda:0 f32[1, 32, 512, 128]" - t1362 = prims.convert_element_type(t1353, dtypes.float32) # t1362: "cuda:0 f32[1, 32, 512, 128]" - t1363 = ltorch.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - # t1363 = prims.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - t1364 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1364: "cuda:0 f32[1, 32, 512, 128]" - t1365 = prims.convert_element_type(t1360, dtypes.float32) # t1365: "cuda:0 f32[1, 32, 512, 128]" - t1366 = ltorch.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - # t1366 = prims.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - t1367 = ltorch.add(t1363, t1366, alpha=None) # t1367: "cuda:0 f32[1, 32, 512, 128]" - # t1367 = prims.add(t1363, t1366) # t1367: "cuda:0 f32[1, 32, 512, 128]" - t1368 = prims.convert_element_type(t1367, dtypes.bfloat16) # t1368: "cuda:0 bf16[1, 32, 512, 128]" - t1369 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1369: "cuda:0 bf16[1, 32, 512, 0]" - t1371 = prims.cat((t1352, t1369), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - t1372 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1372: "cuda:0 bf16[1, 32, 512, 0]" - t1374 = prims.cat((t1368, t1372), -1) # t1374: "cuda:0 bf16[1, 32, 512, 128]" - (t1375, t1376, t1377, t1378) = cudnn_sdpa_fwd(t1371, t1374, t1336, None, 0.0, True, scale=0.08838834764831843) - t1381 = prims.transpose(t1375, (0, 2, 1, 3)) # t1381: "cuda:0 bf16[1, 512, 32, 128]" - t1385 = prims.reshape(t1381, (1, 512, 4096)) # t1385: "cuda:0 bf16[1, 512, 4096]" - t1386 = prims.linear(t1385, t_transformer_h_9_attn_proj_weight, None) # t1386: "cuda:0 bf16[1, 512, 4096]" - t1387 = prims.convert_element_type(t1386, dtypes.float32) # t1387: "cuda:0 f32[1, 512, 4096]" - t1388 = prims.convert_element_type(t1284, dtypes.float32) # t1388: "cuda:0 f32[1, 512, 4096]" - t1389 = ltorch.add(t1387, t1388, alpha=None) # t1389: "cuda:0 f32[1, 512, 4096]" - # t1389 = prims.add(t1387, t1388) # t1389: "cuda:0 f32[1, 512, 4096]" - t1390 = prims.convert_element_type(t1389, dtypes.bfloat16) # t1390: "cuda:0 bf16[1, 512, 4096]" - t1391 = prims.convert_element_type(t1390, dtypes.float32) # t1391: "cuda:0 f32[1, 512, 4096]" - t1392 = ltorch.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - # t1392 = prims.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - t1394 = prims.sum(t1392, (2,)) # t1394: "cuda:0 f32[1, 512]" - t1395 = prims.broadcast_in_dim(t1394, [1, 512, 1], [0, 1]) # t1395: "cuda:0 f32[1, 512, 1]" - t1397 = ltorch.true_divide(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - # t1397 = prims.div(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - t1399 = ltorch.add(t1397, 1e-05, alpha=None) # t1399: "cuda:0 f32[1, 512, 1]" - # t1399 = prims.add(t1397, 1e-05) # t1399: "cuda:0 f32[1, 512, 1]" - t1400 = prims.rsqrt(t1399) # t1400: "cuda:0 f32[1, 512, 1]" - t1401 = prims.broadcast_in_dim(t1400, (1, 512, 4096), (0, 1, 2)) # t1401: "cuda:0 f32[1, 512, 4096]" - t1402 = ltorch.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - # t1402 = prims.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - t1403 = prims.convert_element_type(t1402, dtypes.bfloat16) # t1403: "cuda:0 bf16[1, 512, 4096]" - t1404 = prims.broadcast_in_dim(t_transformer_h_9_norm_2_weight, (1, 512, 4096), (2,)) # t1404: "cuda:0 bf16[1, 512, 4096]" - t1405 = prims.convert_element_type(t1403, dtypes.float32) # t1405: "cuda:0 f32[1, 512, 4096]" - t1406 = prims.convert_element_type(t1404, dtypes.float32) # t1406: "cuda:0 f32[1, 512, 4096]" - t1407 = ltorch.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - # t1407 = prims.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - t1408 = prims.convert_element_type(t1407, dtypes.bfloat16) # t1408: "cuda:0 bf16[1, 512, 4096]" - t1409 = prims.linear(t1408, t_transformer_h_9_mlp_fc_1_weight, None) # t1409: "cuda:0 bf16[1, 512, 11008]" - t1410 = prims.linear(t1408, t_transformer_h_9_mlp_fc_2_weight, None) # t1410: "cuda:0 bf16[1, 512, 11008]" - t1411 = prims.convert_element_type(t1409, dtypes.float32) # t1411: "cuda:0 f32[1, 512, 11008]" - t1412 = prims.neg(t1411) # t1412: "cuda:0 f32[1, 512, 11008]" - t1413 = prims.exp(t1412) # t1413: "cuda:0 f32[1, 512, 11008]" - t1414 = ltorch.add(1.0, t1413, alpha=None) # t1414: "cuda:0 f32[1, 512, 11008]" - # t1414 = prims.add(1.0, t1413) # t1414: "cuda:0 f32[1, 512, 11008]" - t1415 = prims.reciprocal(t1414) # t1415: "cuda:0 f32[1, 512, 11008]" - t1416 = prims.convert_element_type(t1415, dtypes.bfloat16) # t1416: "cuda:0 bf16[1, 512, 11008]" - t1417 = prims.convert_element_type(t1409, dtypes.float32) # t1417: "cuda:0 f32[1, 512, 11008]" - t1418 = prims.convert_element_type(t1416, dtypes.float32) # t1418: "cuda:0 f32[1, 512, 11008]" - t1419 = ltorch.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - # t1419 = prims.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - t1420 = prims.convert_element_type(t1419, dtypes.bfloat16) # t1420: "cuda:0 bf16[1, 512, 11008]" - t1421 = prims.convert_element_type(t1420, dtypes.float32) # t1421: "cuda:0 f32[1, 512, 11008]" - t1422 = prims.convert_element_type(t1410, dtypes.float32) # t1422: "cuda:0 f32[1, 512, 11008]" - t1423 = ltorch.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - # t1423 = prims.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - t1424 = prims.convert_element_type(t1423, dtypes.bfloat16) # t1424: "cuda:0 bf16[1, 512, 11008]" - t1425 = prims.linear(t1424, t_transformer_h_9_mlp_proj_weight, None) # t1425: "cuda:0 bf16[1, 512, 4096]" - t1426 = prims.convert_element_type(t1425, dtypes.float32) # t1426: "cuda:0 f32[1, 512, 4096]" - t1427 = prims.convert_element_type(t1390, dtypes.float32) # t1427: "cuda:0 f32[1, 512, 4096]" - t1428 = ltorch.add(t1426, t1427, alpha=None) # t1428: "cuda:0 f32[1, 512, 4096]" - # t1428 = prims.add(t1426, t1427) # t1428: "cuda:0 f32[1, 512, 4096]" - t1429 = prims.convert_element_type(t1428, dtypes.bfloat16) # t1429: "cuda:0 bf16[1, 512, 4096]" - t1430 = prims.convert_element_type(t1429, dtypes.float32) # t1430: "cuda:0 f32[1, 512, 4096]" - t1431 = ltorch.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - # t1431 = prims.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - t1433 = prims.sum(t1431, (2,)) # t1433: "cuda:0 f32[1, 512]" - t1434 = prims.broadcast_in_dim(t1433, [1, 512, 1], [0, 1]) # t1434: "cuda:0 f32[1, 512, 1]" - t1436 = ltorch.true_divide(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - # t1436 = prims.div(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - t1438 = ltorch.add(t1436, 1e-05, alpha=None) # t1438: "cuda:0 f32[1, 512, 1]" - # t1438 = prims.add(t1436, 1e-05) # t1438: "cuda:0 f32[1, 512, 1]" - t1439 = prims.rsqrt(t1438) # t1439: "cuda:0 f32[1, 512, 1]" - t1440 = prims.broadcast_in_dim(t1439, (1, 512, 4096), (0, 1, 2)) # t1440: "cuda:0 f32[1, 512, 4096]" - t1441 = ltorch.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - # t1441 = prims.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - t1442 = prims.convert_element_type(t1441, dtypes.bfloat16) # t1442: "cuda:0 bf16[1, 512, 4096]" - t1443 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_weight, (1, 512, 4096), (2,)) # t1443: "cuda:0 bf16[1, 512, 4096]" - t1444 = prims.convert_element_type(t1442, dtypes.float32) # t1444: "cuda:0 f32[1, 512, 4096]" - t1445 = prims.convert_element_type(t1443, dtypes.float32) # t1445: "cuda:0 f32[1, 512, 4096]" - t1446 = ltorch.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - # t1446 = prims.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - t1447 = prims.convert_element_type(t1446, dtypes.bfloat16) # t1447: "cuda:0 bf16[1, 512, 4096]" - t1448 = prims.linear(t1447, t_transformer_h_10_attn_attn_weight, None) # t1448: "cuda:0 bf16[1, 512, 12288]" - t1454 = prims.reshape(t1448, (1, 512, 32, 3, 128)) # t1454: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1460 = prims.transpose(t1454, (0, 2, 3, 1, 4)) # t1460: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1461, t1462, t1463) = ltorch.split(t1460, (1, 1, 1), 2) - # t1461 = prims.slice_prim(t1460, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1461: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1462 = prims.slice_prim(t1460, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1462: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1463 = prims.slice_prim(t1460, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1463: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1469 = prims.reshape(t1461, (1, 32, 512, 128)) # t1469: "cuda:0 bf16[1, 32, 512, 128]" - t1475 = prims.reshape(t1462, (1, 32, 512, 128)) # t1475: "cuda:0 bf16[1, 32, 512, 128]" - t1481 = prims.reshape(t1463, (1, 32, 512, 128)) # t1481: "cuda:0 bf16[1, 32, 512, 128]" - t1482 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1482: "cuda:0 bf16[1, 32, 512, 128]" - t1483 = prims.slice_prim(t1482, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1483: "cuda:0 bf16[1, 32, 512, 64]" - t1484 = prims.slice_prim(t1482, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1484: "cuda:0 bf16[1, 32, 512, 64]" - t1485 = prims.convert_element_type(t1484, dtypes.float32) # t1485: "cuda:0 f32[1, 32, 512, 64]" - t1486 = prims.neg(t1485) # t1486: "cuda:0 f32[1, 32, 512, 64]" - t1487 = prims.convert_element_type(t1486, dtypes.bfloat16) # t1487: "cuda:0 bf16[1, 32, 512, 64]" - t1489 = prims.cat((t1487, t1483), -1) # t1489: "cuda:0 bf16[1, 32, 512, 128]" - t1490 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1490: "cuda:0 f32[1, 32, 512, 128]" - t1491 = prims.convert_element_type(t1482, dtypes.float32) # t1491: "cuda:0 f32[1, 32, 512, 128]" - t1492 = ltorch.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - # t1492 = prims.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - t1493 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1493: "cuda:0 f32[1, 32, 512, 128]" - t1494 = prims.convert_element_type(t1489, dtypes.float32) # t1494: "cuda:0 f32[1, 32, 512, 128]" - t1495 = ltorch.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - # t1495 = prims.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - t1496 = ltorch.add(t1492, t1495, alpha=None) # t1496: "cuda:0 f32[1, 32, 512, 128]" - # t1496 = prims.add(t1492, t1495) # t1496: "cuda:0 f32[1, 32, 512, 128]" - t1497 = prims.convert_element_type(t1496, dtypes.bfloat16) # t1497: "cuda:0 bf16[1, 32, 512, 128]" - t1498 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1498: "cuda:0 bf16[1, 32, 512, 128]" - t1499 = prims.slice_prim(t1498, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1499: "cuda:0 bf16[1, 32, 512, 64]" - t1500 = prims.slice_prim(t1498, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1500: "cuda:0 bf16[1, 32, 512, 64]" - t1501 = prims.convert_element_type(t1500, dtypes.float32) # t1501: "cuda:0 f32[1, 32, 512, 64]" - t1502 = prims.neg(t1501) # t1502: "cuda:0 f32[1, 32, 512, 64]" - t1503 = prims.convert_element_type(t1502, dtypes.bfloat16) # t1503: "cuda:0 bf16[1, 32, 512, 64]" - t1505 = prims.cat((t1503, t1499), -1) # t1505: "cuda:0 bf16[1, 32, 512, 128]" - t1506 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1506: "cuda:0 f32[1, 32, 512, 128]" - t1507 = prims.convert_element_type(t1498, dtypes.float32) # t1507: "cuda:0 f32[1, 32, 512, 128]" - t1508 = ltorch.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - # t1508 = prims.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - t1509 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1509: "cuda:0 f32[1, 32, 512, 128]" - t1510 = prims.convert_element_type(t1505, dtypes.float32) # t1510: "cuda:0 f32[1, 32, 512, 128]" - t1511 = ltorch.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - # t1511 = prims.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - t1512 = ltorch.add(t1508, t1511, alpha=None) # t1512: "cuda:0 f32[1, 32, 512, 128]" - # t1512 = prims.add(t1508, t1511) # t1512: "cuda:0 f32[1, 32, 512, 128]" - t1513 = prims.convert_element_type(t1512, dtypes.bfloat16) # t1513: "cuda:0 bf16[1, 32, 512, 128]" - t1514 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1514: "cuda:0 bf16[1, 32, 512, 0]" - t1516 = prims.cat((t1497, t1514), -1) # t1516: "cuda:0 bf16[1, 32, 512, 128]" - t1517 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1517: "cuda:0 bf16[1, 32, 512, 0]" - t1519 = prims.cat((t1513, t1517), -1) # t1519: "cuda:0 bf16[1, 32, 512, 128]" - (t1520, t1521, t1522, t1523) = cudnn_sdpa_fwd(t1516, t1519, t1481, None, 0.0, True, scale=0.08838834764831843) - t1526 = prims.transpose(t1520, (0, 2, 1, 3)) # t1526: "cuda:0 bf16[1, 512, 32, 128]" - t1530 = prims.reshape(t1526, (1, 512, 4096)) # t1530: "cuda:0 bf16[1, 512, 4096]" - t1531 = prims.linear(t1530, t_transformer_h_10_attn_proj_weight, None) # t1531: "cuda:0 bf16[1, 512, 4096]" - t1532 = prims.convert_element_type(t1531, dtypes.float32) # t1532: "cuda:0 f32[1, 512, 4096]" - t1533 = prims.convert_element_type(t1429, dtypes.float32) # t1533: "cuda:0 f32[1, 512, 4096]" - t1534 = ltorch.add(t1532, t1533, alpha=None) # t1534: "cuda:0 f32[1, 512, 4096]" - # t1534 = prims.add(t1532, t1533) # t1534: "cuda:0 f32[1, 512, 4096]" - t1535 = prims.convert_element_type(t1534, dtypes.bfloat16) # t1535: "cuda:0 bf16[1, 512, 4096]" - t1536 = prims.convert_element_type(t1535, dtypes.float32) # t1536: "cuda:0 f32[1, 512, 4096]" - t1537 = ltorch.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - # t1537 = prims.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - t1539 = prims.sum(t1537, (2,)) # t1539: "cuda:0 f32[1, 512]" - t1540 = prims.broadcast_in_dim(t1539, [1, 512, 1], [0, 1]) # t1540: "cuda:0 f32[1, 512, 1]" - t1542 = ltorch.true_divide(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - # t1542 = prims.div(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - t1544 = ltorch.add(t1542, 1e-05, alpha=None) # t1544: "cuda:0 f32[1, 512, 1]" - # t1544 = prims.add(t1542, 1e-05) # t1544: "cuda:0 f32[1, 512, 1]" - t1545 = prims.rsqrt(t1544) # t1545: "cuda:0 f32[1, 512, 1]" - t1546 = prims.broadcast_in_dim(t1545, (1, 512, 4096), (0, 1, 2)) # t1546: "cuda:0 f32[1, 512, 4096]" - t1547 = ltorch.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - # t1547 = prims.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - t1548 = prims.convert_element_type(t1547, dtypes.bfloat16) # t1548: "cuda:0 bf16[1, 512, 4096]" - t1549 = prims.broadcast_in_dim(t_transformer_h_10_norm_2_weight, (1, 512, 4096), (2,)) # t1549: "cuda:0 bf16[1, 512, 4096]" - t1550 = prims.convert_element_type(t1548, dtypes.float32) # t1550: "cuda:0 f32[1, 512, 4096]" - t1551 = prims.convert_element_type(t1549, dtypes.float32) # t1551: "cuda:0 f32[1, 512, 4096]" - t1552 = ltorch.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - # t1552 = prims.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - t1553 = prims.convert_element_type(t1552, dtypes.bfloat16) # t1553: "cuda:0 bf16[1, 512, 4096]" - t1554 = prims.linear(t1553, t_transformer_h_10_mlp_fc_1_weight, None) # t1554: "cuda:0 bf16[1, 512, 11008]" - t1555 = prims.linear(t1553, t_transformer_h_10_mlp_fc_2_weight, None) # t1555: "cuda:0 bf16[1, 512, 11008]" - t1556 = prims.convert_element_type(t1554, dtypes.float32) # t1556: "cuda:0 f32[1, 512, 11008]" - t1557 = prims.neg(t1556) # t1557: "cuda:0 f32[1, 512, 11008]" - t1558 = prims.exp(t1557) # t1558: "cuda:0 f32[1, 512, 11008]" - t1559 = ltorch.add(1.0, t1558, alpha=None) # t1559: "cuda:0 f32[1, 512, 11008]" - # t1559 = prims.add(1.0, t1558) # t1559: "cuda:0 f32[1, 512, 11008]" - t1560 = prims.reciprocal(t1559) # t1560: "cuda:0 f32[1, 512, 11008]" - t1561 = prims.convert_element_type(t1560, dtypes.bfloat16) # t1561: "cuda:0 bf16[1, 512, 11008]" - t1562 = prims.convert_element_type(t1554, dtypes.float32) # t1562: "cuda:0 f32[1, 512, 11008]" - t1563 = prims.convert_element_type(t1561, dtypes.float32) # t1563: "cuda:0 f32[1, 512, 11008]" - t1564 = ltorch.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - # t1564 = prims.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: "cuda:0 bf16[1, 512, 11008]" - t1566 = prims.convert_element_type(t1565, dtypes.float32) # t1566: "cuda:0 f32[1, 512, 11008]" - t1567 = prims.convert_element_type(t1555, dtypes.float32) # t1567: "cuda:0 f32[1, 512, 11008]" - t1568 = ltorch.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - # t1568 = prims.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - t1569 = prims.convert_element_type(t1568, dtypes.bfloat16) # t1569: "cuda:0 bf16[1, 512, 11008]" - t1570 = prims.linear(t1569, t_transformer_h_10_mlp_proj_weight, None) # t1570: "cuda:0 bf16[1, 512, 4096]" - t1571 = prims.convert_element_type(t1570, dtypes.float32) # t1571: "cuda:0 f32[1, 512, 4096]" - t1572 = prims.convert_element_type(t1535, dtypes.float32) # t1572: "cuda:0 f32[1, 512, 4096]" - t1573 = ltorch.add(t1571, t1572, alpha=None) # t1573: "cuda:0 f32[1, 512, 4096]" - # t1573 = prims.add(t1571, t1572) # t1573: "cuda:0 f32[1, 512, 4096]" - t1574 = prims.convert_element_type(t1573, dtypes.bfloat16) # t1574: "cuda:0 bf16[1, 512, 4096]" - t1575 = prims.convert_element_type(t1574, dtypes.float32) # t1575: "cuda:0 f32[1, 512, 4096]" - t1576 = ltorch.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - # t1576 = prims.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - t1578 = prims.sum(t1576, (2,)) # t1578: "cuda:0 f32[1, 512]" - t1579 = prims.broadcast_in_dim(t1578, [1, 512, 1], [0, 1]) # t1579: "cuda:0 f32[1, 512, 1]" - t1581 = ltorch.true_divide(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - # t1581 = prims.div(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - t1583 = ltorch.add(t1581, 1e-05, alpha=None) # t1583: "cuda:0 f32[1, 512, 1]" - # t1583 = prims.add(t1581, 1e-05) # t1583: "cuda:0 f32[1, 512, 1]" - t1584 = prims.rsqrt(t1583) # t1584: "cuda:0 f32[1, 512, 1]" - t1585 = prims.broadcast_in_dim(t1584, (1, 512, 4096), (0, 1, 2)) # t1585: "cuda:0 f32[1, 512, 4096]" - t1586 = ltorch.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - # t1586 = prims.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - t1587 = prims.convert_element_type(t1586, dtypes.bfloat16) # t1587: "cuda:0 bf16[1, 512, 4096]" - t1588 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_weight, (1, 512, 4096), (2,)) # t1588: "cuda:0 bf16[1, 512, 4096]" - t1589 = prims.convert_element_type(t1587, dtypes.float32) # t1589: "cuda:0 f32[1, 512, 4096]" - t1590 = prims.convert_element_type(t1588, dtypes.float32) # t1590: "cuda:0 f32[1, 512, 4096]" - t1591 = ltorch.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - # t1591 = prims.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - t1592 = prims.convert_element_type(t1591, dtypes.bfloat16) # t1592: "cuda:0 bf16[1, 512, 4096]" - t1593 = prims.linear(t1592, t_transformer_h_11_attn_attn_weight, None) # t1593: "cuda:0 bf16[1, 512, 12288]" - t1599 = prims.reshape(t1593, (1, 512, 32, 3, 128)) # t1599: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1605 = prims.transpose(t1599, (0, 2, 3, 1, 4)) # t1605: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1606, t1607, t1608) = ltorch.split(t1605, (1, 1, 1), 2) - # t1606 = prims.slice_prim(t1605, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1606: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1607 = prims.slice_prim(t1605, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1607: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1608 = prims.slice_prim(t1605, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1608: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1614 = prims.reshape(t1606, (1, 32, 512, 128)) # t1614: "cuda:0 bf16[1, 32, 512, 128]" - t1620 = prims.reshape(t1607, (1, 32, 512, 128)) # t1620: "cuda:0 bf16[1, 32, 512, 128]" - t1626 = prims.reshape(t1608, (1, 32, 512, 128)) # t1626: "cuda:0 bf16[1, 32, 512, 128]" - t1627 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1627: "cuda:0 bf16[1, 32, 512, 128]" - t1628 = prims.slice_prim(t1627, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1628: "cuda:0 bf16[1, 32, 512, 64]" - t1629 = prims.slice_prim(t1627, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1629: "cuda:0 bf16[1, 32, 512, 64]" - t1630 = prims.convert_element_type(t1629, dtypes.float32) # t1630: "cuda:0 f32[1, 32, 512, 64]" - t1631 = prims.neg(t1630) # t1631: "cuda:0 f32[1, 32, 512, 64]" - t1632 = prims.convert_element_type(t1631, dtypes.bfloat16) # t1632: "cuda:0 bf16[1, 32, 512, 64]" - t1634 = prims.cat((t1632, t1628), -1) # t1634: "cuda:0 bf16[1, 32, 512, 128]" - t1635 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1635: "cuda:0 f32[1, 32, 512, 128]" - t1636 = prims.convert_element_type(t1627, dtypes.float32) # t1636: "cuda:0 f32[1, 32, 512, 128]" - t1637 = ltorch.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - # t1637 = prims.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - t1638 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1638: "cuda:0 f32[1, 32, 512, 128]" - t1639 = prims.convert_element_type(t1634, dtypes.float32) # t1639: "cuda:0 f32[1, 32, 512, 128]" - t1640 = ltorch.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - # t1640 = prims.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - t1641 = ltorch.add(t1637, t1640, alpha=None) # t1641: "cuda:0 f32[1, 32, 512, 128]" - # t1641 = prims.add(t1637, t1640) # t1641: "cuda:0 f32[1, 32, 512, 128]" - t1642 = prims.convert_element_type(t1641, dtypes.bfloat16) # t1642: "cuda:0 bf16[1, 32, 512, 128]" - t1643 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1643: "cuda:0 bf16[1, 32, 512, 128]" - t1644 = prims.slice_prim(t1643, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1644: "cuda:0 bf16[1, 32, 512, 64]" - t1645 = prims.slice_prim(t1643, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1645: "cuda:0 bf16[1, 32, 512, 64]" - t1646 = prims.convert_element_type(t1645, dtypes.float32) # t1646: "cuda:0 f32[1, 32, 512, 64]" - t1647 = prims.neg(t1646) # t1647: "cuda:0 f32[1, 32, 512, 64]" - t1648 = prims.convert_element_type(t1647, dtypes.bfloat16) # t1648: "cuda:0 bf16[1, 32, 512, 64]" - t1650 = prims.cat((t1648, t1644), -1) # t1650: "cuda:0 bf16[1, 32, 512, 128]" - t1651 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1651: "cuda:0 f32[1, 32, 512, 128]" - t1652 = prims.convert_element_type(t1643, dtypes.float32) # t1652: "cuda:0 f32[1, 32, 512, 128]" - t1653 = ltorch.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - # t1653 = prims.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - t1654 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1654: "cuda:0 f32[1, 32, 512, 128]" - t1655 = prims.convert_element_type(t1650, dtypes.float32) # t1655: "cuda:0 f32[1, 32, 512, 128]" - t1656 = ltorch.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - # t1656 = prims.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - t1657 = ltorch.add(t1653, t1656, alpha=None) # t1657: "cuda:0 f32[1, 32, 512, 128]" - # t1657 = prims.add(t1653, t1656) # t1657: "cuda:0 f32[1, 32, 512, 128]" - t1658 = prims.convert_element_type(t1657, dtypes.bfloat16) # t1658: "cuda:0 bf16[1, 32, 512, 128]" - t1659 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1659: "cuda:0 bf16[1, 32, 512, 0]" - t1661 = prims.cat((t1642, t1659), -1) # t1661: "cuda:0 bf16[1, 32, 512, 128]" - t1662 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1662: "cuda:0 bf16[1, 32, 512, 0]" - t1664 = prims.cat((t1658, t1662), -1) # t1664: "cuda:0 bf16[1, 32, 512, 128]" - (t1665, t1666, t1667, t1668) = cudnn_sdpa_fwd(t1661, t1664, t1626, None, 0.0, True, scale=0.08838834764831843) - t1671 = prims.transpose(t1665, (0, 2, 1, 3)) # t1671: "cuda:0 bf16[1, 512, 32, 128]" - t1675 = prims.reshape(t1671, (1, 512, 4096)) # t1675: "cuda:0 bf16[1, 512, 4096]" - t1676 = prims.linear(t1675, t_transformer_h_11_attn_proj_weight, None) # t1676: "cuda:0 bf16[1, 512, 4096]" - t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: "cuda:0 f32[1, 512, 4096]" - t1678 = prims.convert_element_type(t1574, dtypes.float32) # t1678: "cuda:0 f32[1, 512, 4096]" - t1679 = ltorch.add(t1677, t1678, alpha=None) # t1679: "cuda:0 f32[1, 512, 4096]" - # t1679 = prims.add(t1677, t1678) # t1679: "cuda:0 f32[1, 512, 4096]" - t1680 = prims.convert_element_type(t1679, dtypes.bfloat16) # t1680: "cuda:0 bf16[1, 512, 4096]" - t1681 = prims.convert_element_type(t1680, dtypes.float32) # t1681: "cuda:0 f32[1, 512, 4096]" - t1682 = ltorch.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - # t1682 = prims.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - t1684 = prims.sum(t1682, (2,)) # t1684: "cuda:0 f32[1, 512]" - t1685 = prims.broadcast_in_dim(t1684, [1, 512, 1], [0, 1]) # t1685: "cuda:0 f32[1, 512, 1]" - t1687 = ltorch.true_divide(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - # t1687 = prims.div(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - t1689 = ltorch.add(t1687, 1e-05, alpha=None) # t1689: "cuda:0 f32[1, 512, 1]" - # t1689 = prims.add(t1687, 1e-05) # t1689: "cuda:0 f32[1, 512, 1]" - t1690 = prims.rsqrt(t1689) # t1690: "cuda:0 f32[1, 512, 1]" - t1691 = prims.broadcast_in_dim(t1690, (1, 512, 4096), (0, 1, 2)) # t1691: "cuda:0 f32[1, 512, 4096]" - t1692 = ltorch.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - # t1692 = prims.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - t1693 = prims.convert_element_type(t1692, dtypes.bfloat16) # t1693: "cuda:0 bf16[1, 512, 4096]" - t1694 = prims.broadcast_in_dim(t_transformer_h_11_norm_2_weight, (1, 512, 4096), (2,)) # t1694: "cuda:0 bf16[1, 512, 4096]" - t1695 = prims.convert_element_type(t1693, dtypes.float32) # t1695: "cuda:0 f32[1, 512, 4096]" - t1696 = prims.convert_element_type(t1694, dtypes.float32) # t1696: "cuda:0 f32[1, 512, 4096]" - t1697 = ltorch.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - # t1697 = prims.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - t1698 = prims.convert_element_type(t1697, dtypes.bfloat16) # t1698: "cuda:0 bf16[1, 512, 4096]" - t1699 = prims.linear(t1698, t_transformer_h_11_mlp_fc_1_weight, None) # t1699: "cuda:0 bf16[1, 512, 11008]" - t1700 = prims.linear(t1698, t_transformer_h_11_mlp_fc_2_weight, None) # t1700: "cuda:0 bf16[1, 512, 11008]" - t1701 = prims.convert_element_type(t1699, dtypes.float32) # t1701: "cuda:0 f32[1, 512, 11008]" - t1702 = prims.neg(t1701) # t1702: "cuda:0 f32[1, 512, 11008]" - t1703 = prims.exp(t1702) # t1703: "cuda:0 f32[1, 512, 11008]" - t1704 = ltorch.add(1.0, t1703, alpha=None) # t1704: "cuda:0 f32[1, 512, 11008]" - # t1704 = prims.add(1.0, t1703) # t1704: "cuda:0 f32[1, 512, 11008]" - t1705 = prims.reciprocal(t1704) # t1705: "cuda:0 f32[1, 512, 11008]" - t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: "cuda:0 bf16[1, 512, 11008]" - t1707 = prims.convert_element_type(t1699, dtypes.float32) # t1707: "cuda:0 f32[1, 512, 11008]" - t1708 = prims.convert_element_type(t1706, dtypes.float32) # t1708: "cuda:0 f32[1, 512, 11008]" - t1709 = ltorch.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - # t1709 = prims.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - t1710 = prims.convert_element_type(t1709, dtypes.bfloat16) # t1710: "cuda:0 bf16[1, 512, 11008]" - t1711 = prims.convert_element_type(t1710, dtypes.float32) # t1711: "cuda:0 f32[1, 512, 11008]" - t1712 = prims.convert_element_type(t1700, dtypes.float32) # t1712: "cuda:0 f32[1, 512, 11008]" - t1713 = ltorch.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - # t1713 = prims.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - t1714 = prims.convert_element_type(t1713, dtypes.bfloat16) # t1714: "cuda:0 bf16[1, 512, 11008]" - t1715 = prims.linear(t1714, t_transformer_h_11_mlp_proj_weight, None) # t1715: "cuda:0 bf16[1, 512, 4096]" - t1716 = prims.convert_element_type(t1715, dtypes.float32) # t1716: "cuda:0 f32[1, 512, 4096]" - t1717 = prims.convert_element_type(t1680, dtypes.float32) # t1717: "cuda:0 f32[1, 512, 4096]" - t1718 = ltorch.add(t1716, t1717, alpha=None) # t1718: "cuda:0 f32[1, 512, 4096]" - # t1718 = prims.add(t1716, t1717) # t1718: "cuda:0 f32[1, 512, 4096]" - t1719 = prims.convert_element_type(t1718, dtypes.bfloat16) # t1719: "cuda:0 bf16[1, 512, 4096]" - t1720 = prims.convert_element_type(t1719, dtypes.float32) # t1720: "cuda:0 f32[1, 512, 4096]" - t1721 = ltorch.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - # t1721 = prims.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - t1723 = prims.sum(t1721, (2,)) # t1723: "cuda:0 f32[1, 512]" - t1724 = prims.broadcast_in_dim(t1723, [1, 512, 1], [0, 1]) # t1724: "cuda:0 f32[1, 512, 1]" - t1726 = ltorch.true_divide(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - # t1726 = prims.div(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - t1728 = ltorch.add(t1726, 1e-05, alpha=None) # t1728: "cuda:0 f32[1, 512, 1]" - # t1728 = prims.add(t1726, 1e-05) # t1728: "cuda:0 f32[1, 512, 1]" - t1729 = prims.rsqrt(t1728) # t1729: "cuda:0 f32[1, 512, 1]" - t1730 = prims.broadcast_in_dim(t1729, (1, 512, 4096), (0, 1, 2)) # t1730: "cuda:0 f32[1, 512, 4096]" - t1731 = ltorch.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - # t1731 = prims.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - t1732 = prims.convert_element_type(t1731, dtypes.bfloat16) # t1732: "cuda:0 bf16[1, 512, 4096]" - t1733 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_weight, (1, 512, 4096), (2,)) # t1733: "cuda:0 bf16[1, 512, 4096]" - t1734 = prims.convert_element_type(t1732, dtypes.float32) # t1734: "cuda:0 f32[1, 512, 4096]" - t1735 = prims.convert_element_type(t1733, dtypes.float32) # t1735: "cuda:0 f32[1, 512, 4096]" - t1736 = ltorch.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - # t1736 = prims.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: "cuda:0 bf16[1, 512, 4096]" - t1738 = prims.linear(t1737, t_transformer_h_12_attn_attn_weight, None) # t1738: "cuda:0 bf16[1, 512, 12288]" - t1744 = prims.reshape(t1738, (1, 512, 32, 3, 128)) # t1744: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1750 = prims.transpose(t1744, (0, 2, 3, 1, 4)) # t1750: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1751, t1752, t1753) = ltorch.split(t1750, (1, 1, 1), 2) - # t1751 = prims.slice_prim(t1750, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1751: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1752 = prims.slice_prim(t1750, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1752: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1753 = prims.slice_prim(t1750, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1753: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1759 = prims.reshape(t1751, (1, 32, 512, 128)) # t1759: "cuda:0 bf16[1, 32, 512, 128]" - t1765 = prims.reshape(t1752, (1, 32, 512, 128)) # t1765: "cuda:0 bf16[1, 32, 512, 128]" - t1771 = prims.reshape(t1753, (1, 32, 512, 128)) # t1771: "cuda:0 bf16[1, 32, 512, 128]" - t1772 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1772: "cuda:0 bf16[1, 32, 512, 128]" - t1773 = prims.slice_prim(t1772, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1773: "cuda:0 bf16[1, 32, 512, 64]" - t1774 = prims.slice_prim(t1772, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1774: "cuda:0 bf16[1, 32, 512, 64]" - t1775 = prims.convert_element_type(t1774, dtypes.float32) # t1775: "cuda:0 f32[1, 32, 512, 64]" - t1776 = prims.neg(t1775) # t1776: "cuda:0 f32[1, 32, 512, 64]" - t1777 = prims.convert_element_type(t1776, dtypes.bfloat16) # t1777: "cuda:0 bf16[1, 32, 512, 64]" - t1779 = prims.cat((t1777, t1773), -1) # t1779: "cuda:0 bf16[1, 32, 512, 128]" - t1780 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1780: "cuda:0 f32[1, 32, 512, 128]" - t1781 = prims.convert_element_type(t1772, dtypes.float32) # t1781: "cuda:0 f32[1, 32, 512, 128]" - t1782 = ltorch.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - # t1782 = prims.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - t1783 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1783: "cuda:0 f32[1, 32, 512, 128]" - t1784 = prims.convert_element_type(t1779, dtypes.float32) # t1784: "cuda:0 f32[1, 32, 512, 128]" - t1785 = ltorch.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - # t1785 = prims.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - t1786 = ltorch.add(t1782, t1785, alpha=None) # t1786: "cuda:0 f32[1, 32, 512, 128]" - # t1786 = prims.add(t1782, t1785) # t1786: "cuda:0 f32[1, 32, 512, 128]" - t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: "cuda:0 bf16[1, 32, 512, 128]" - t1788 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1788: "cuda:0 bf16[1, 32, 512, 128]" - t1789 = prims.slice_prim(t1788, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1789: "cuda:0 bf16[1, 32, 512, 64]" - t1790 = prims.slice_prim(t1788, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1790: "cuda:0 bf16[1, 32, 512, 64]" - t1791 = prims.convert_element_type(t1790, dtypes.float32) # t1791: "cuda:0 f32[1, 32, 512, 64]" - t1792 = prims.neg(t1791) # t1792: "cuda:0 f32[1, 32, 512, 64]" - t1793 = prims.convert_element_type(t1792, dtypes.bfloat16) # t1793: "cuda:0 bf16[1, 32, 512, 64]" - t1795 = prims.cat((t1793, t1789), -1) # t1795: "cuda:0 bf16[1, 32, 512, 128]" - t1796 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1796: "cuda:0 f32[1, 32, 512, 128]" - t1797 = prims.convert_element_type(t1788, dtypes.float32) # t1797: "cuda:0 f32[1, 32, 512, 128]" - t1798 = ltorch.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - # t1798 = prims.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - t1799 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1799: "cuda:0 f32[1, 32, 512, 128]" - t1800 = prims.convert_element_type(t1795, dtypes.float32) # t1800: "cuda:0 f32[1, 32, 512, 128]" - t1801 = ltorch.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - # t1801 = prims.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - t1802 = ltorch.add(t1798, t1801, alpha=None) # t1802: "cuda:0 f32[1, 32, 512, 128]" - # t1802 = prims.add(t1798, t1801) # t1802: "cuda:0 f32[1, 32, 512, 128]" - t1803 = prims.convert_element_type(t1802, dtypes.bfloat16) # t1803: "cuda:0 bf16[1, 32, 512, 128]" - t1804 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1804: "cuda:0 bf16[1, 32, 512, 0]" - t1806 = prims.cat((t1787, t1804), -1) # t1806: "cuda:0 bf16[1, 32, 512, 128]" - t1807 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1807: "cuda:0 bf16[1, 32, 512, 0]" - t1809 = prims.cat((t1803, t1807), -1) # t1809: "cuda:0 bf16[1, 32, 512, 128]" - (t1810, t1811, t1812, t1813) = cudnn_sdpa_fwd(t1806, t1809, t1771, None, 0.0, True, scale=0.08838834764831843) - t1816 = prims.transpose(t1810, (0, 2, 1, 3)) # t1816: "cuda:0 bf16[1, 512, 32, 128]" - t1820 = prims.reshape(t1816, (1, 512, 4096)) # t1820: "cuda:0 bf16[1, 512, 4096]" - t1821 = prims.linear(t1820, t_transformer_h_12_attn_proj_weight, None) # t1821: "cuda:0 bf16[1, 512, 4096]" - t1822 = prims.convert_element_type(t1821, dtypes.float32) # t1822: "cuda:0 f32[1, 512, 4096]" - t1823 = prims.convert_element_type(t1719, dtypes.float32) # t1823: "cuda:0 f32[1, 512, 4096]" - t1824 = ltorch.add(t1822, t1823, alpha=None) # t1824: "cuda:0 f32[1, 512, 4096]" - # t1824 = prims.add(t1822, t1823) # t1824: "cuda:0 f32[1, 512, 4096]" - t1825 = prims.convert_element_type(t1824, dtypes.bfloat16) # t1825: "cuda:0 bf16[1, 512, 4096]" - t1826 = prims.convert_element_type(t1825, dtypes.float32) # t1826: "cuda:0 f32[1, 512, 4096]" - t1827 = ltorch.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - # t1827 = prims.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - t1829 = prims.sum(t1827, (2,)) # t1829: "cuda:0 f32[1, 512]" - t1830 = prims.broadcast_in_dim(t1829, [1, 512, 1], [0, 1]) # t1830: "cuda:0 f32[1, 512, 1]" - t1832 = ltorch.true_divide(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - # t1832 = prims.div(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - t1834 = ltorch.add(t1832, 1e-05, alpha=None) # t1834: "cuda:0 f32[1, 512, 1]" - # t1834 = prims.add(t1832, 1e-05) # t1834: "cuda:0 f32[1, 512, 1]" - t1835 = prims.rsqrt(t1834) # t1835: "cuda:0 f32[1, 512, 1]" - t1836 = prims.broadcast_in_dim(t1835, (1, 512, 4096), (0, 1, 2)) # t1836: "cuda:0 f32[1, 512, 4096]" - t1837 = ltorch.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - # t1837 = prims.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - t1838 = prims.convert_element_type(t1837, dtypes.bfloat16) # t1838: "cuda:0 bf16[1, 512, 4096]" - t1839 = prims.broadcast_in_dim(t_transformer_h_12_norm_2_weight, (1, 512, 4096), (2,)) # t1839: "cuda:0 bf16[1, 512, 4096]" - t1840 = prims.convert_element_type(t1838, dtypes.float32) # t1840: "cuda:0 f32[1, 512, 4096]" - t1841 = prims.convert_element_type(t1839, dtypes.float32) # t1841: "cuda:0 f32[1, 512, 4096]" - t1842 = ltorch.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - # t1842 = prims.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - t1843 = prims.convert_element_type(t1842, dtypes.bfloat16) # t1843: "cuda:0 bf16[1, 512, 4096]" - t1844 = prims.linear(t1843, t_transformer_h_12_mlp_fc_1_weight, None) # t1844: "cuda:0 bf16[1, 512, 11008]" - t1845 = prims.linear(t1843, t_transformer_h_12_mlp_fc_2_weight, None) # t1845: "cuda:0 bf16[1, 512, 11008]" - t1846 = prims.convert_element_type(t1844, dtypes.float32) # t1846: "cuda:0 f32[1, 512, 11008]" - t1847 = prims.neg(t1846) # t1847: "cuda:0 f32[1, 512, 11008]" - t1848 = prims.exp(t1847) # t1848: "cuda:0 f32[1, 512, 11008]" - t1849 = ltorch.add(1.0, t1848, alpha=None) # t1849: "cuda:0 f32[1, 512, 11008]" - # t1849 = prims.add(1.0, t1848) # t1849: "cuda:0 f32[1, 512, 11008]" - t1850 = prims.reciprocal(t1849) # t1850: "cuda:0 f32[1, 512, 11008]" - t1851 = prims.convert_element_type(t1850, dtypes.bfloat16) # t1851: "cuda:0 bf16[1, 512, 11008]" - t1852 = prims.convert_element_type(t1844, dtypes.float32) # t1852: "cuda:0 f32[1, 512, 11008]" - t1853 = prims.convert_element_type(t1851, dtypes.float32) # t1853: "cuda:0 f32[1, 512, 11008]" - t1854 = ltorch.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - # t1854 = prims.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - t1855 = prims.convert_element_type(t1854, dtypes.bfloat16) # t1855: "cuda:0 bf16[1, 512, 11008]" - t1856 = prims.convert_element_type(t1855, dtypes.float32) # t1856: "cuda:0 f32[1, 512, 11008]" - t1857 = prims.convert_element_type(t1845, dtypes.float32) # t1857: "cuda:0 f32[1, 512, 11008]" - t1858 = ltorch.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - # t1858 = prims.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - t1859 = prims.convert_element_type(t1858, dtypes.bfloat16) # t1859: "cuda:0 bf16[1, 512, 11008]" - t1860 = prims.linear(t1859, t_transformer_h_12_mlp_proj_weight, None) # t1860: "cuda:0 bf16[1, 512, 4096]" - t1861 = prims.convert_element_type(t1860, dtypes.float32) # t1861: "cuda:0 f32[1, 512, 4096]" - t1862 = prims.convert_element_type(t1825, dtypes.float32) # t1862: "cuda:0 f32[1, 512, 4096]" - t1863 = ltorch.add(t1861, t1862, alpha=None) # t1863: "cuda:0 f32[1, 512, 4096]" - # t1863 = prims.add(t1861, t1862) # t1863: "cuda:0 f32[1, 512, 4096]" - t1864 = prims.convert_element_type(t1863, dtypes.bfloat16) # t1864: "cuda:0 bf16[1, 512, 4096]" - t1865 = prims.convert_element_type(t1864, dtypes.float32) # t1865: "cuda:0 f32[1, 512, 4096]" - t1866 = ltorch.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - # t1866 = prims.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - t1868 = prims.sum(t1866, (2,)) # t1868: "cuda:0 f32[1, 512]" - t1869 = prims.broadcast_in_dim(t1868, [1, 512, 1], [0, 1]) # t1869: "cuda:0 f32[1, 512, 1]" - t1871 = ltorch.true_divide(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - # t1871 = prims.div(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - t1873 = ltorch.add(t1871, 1e-05, alpha=None) # t1873: "cuda:0 f32[1, 512, 1]" - # t1873 = prims.add(t1871, 1e-05) # t1873: "cuda:0 f32[1, 512, 1]" - t1874 = prims.rsqrt(t1873) # t1874: "cuda:0 f32[1, 512, 1]" - t1875 = prims.broadcast_in_dim(t1874, (1, 512, 4096), (0, 1, 2)) # t1875: "cuda:0 f32[1, 512, 4096]" - t1876 = ltorch.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - # t1876 = prims.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - t1877 = prims.convert_element_type(t1876, dtypes.bfloat16) # t1877: "cuda:0 bf16[1, 512, 4096]" - t1878 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_weight, (1, 512, 4096), (2,)) # t1878: "cuda:0 bf16[1, 512, 4096]" - t1879 = prims.convert_element_type(t1877, dtypes.float32) # t1879: "cuda:0 f32[1, 512, 4096]" - t1880 = prims.convert_element_type(t1878, dtypes.float32) # t1880: "cuda:0 f32[1, 512, 4096]" - t1881 = ltorch.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - # t1881 = prims.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - t1882 = prims.convert_element_type(t1881, dtypes.bfloat16) # t1882: "cuda:0 bf16[1, 512, 4096]" - t1883 = prims.linear(t1882, t_transformer_h_13_attn_attn_weight, None) # t1883: "cuda:0 bf16[1, 512, 12288]" - t1889 = prims.reshape(t1883, (1, 512, 32, 3, 128)) # t1889: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1895 = prims.transpose(t1889, (0, 2, 3, 1, 4)) # t1895: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1896, t1897, t1898) = ltorch.split(t1895, (1, 1, 1), 2) - # t1896 = prims.slice_prim(t1895, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1896: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1897 = prims.slice_prim(t1895, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1897: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1898 = prims.slice_prim(t1895, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1898: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1904 = prims.reshape(t1896, (1, 32, 512, 128)) # t1904: "cuda:0 bf16[1, 32, 512, 128]" - t1910 = prims.reshape(t1897, (1, 32, 512, 128)) # t1910: "cuda:0 bf16[1, 32, 512, 128]" - t1916 = prims.reshape(t1898, (1, 32, 512, 128)) # t1916: "cuda:0 bf16[1, 32, 512, 128]" - t1917 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - t1918 = prims.slice_prim(t1917, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1918: "cuda:0 bf16[1, 32, 512, 64]" - t1919 = prims.slice_prim(t1917, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1919: "cuda:0 bf16[1, 32, 512, 64]" - t1920 = prims.convert_element_type(t1919, dtypes.float32) # t1920: "cuda:0 f32[1, 32, 512, 64]" - t1921 = prims.neg(t1920) # t1921: "cuda:0 f32[1, 32, 512, 64]" - t1922 = prims.convert_element_type(t1921, dtypes.bfloat16) # t1922: "cuda:0 bf16[1, 32, 512, 64]" - t1924 = prims.cat((t1922, t1918), -1) # t1924: "cuda:0 bf16[1, 32, 512, 128]" - t1925 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1925: "cuda:0 f32[1, 32, 512, 128]" - t1926 = prims.convert_element_type(t1917, dtypes.float32) # t1926: "cuda:0 f32[1, 32, 512, 128]" - t1927 = ltorch.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - # t1927 = prims.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - t1928 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1928: "cuda:0 f32[1, 32, 512, 128]" - t1929 = prims.convert_element_type(t1924, dtypes.float32) # t1929: "cuda:0 f32[1, 32, 512, 128]" - t1930 = ltorch.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - # t1930 = prims.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - t1931 = ltorch.add(t1927, t1930, alpha=None) # t1931: "cuda:0 f32[1, 32, 512, 128]" - # t1931 = prims.add(t1927, t1930) # t1931: "cuda:0 f32[1, 32, 512, 128]" - t1932 = prims.convert_element_type(t1931, dtypes.bfloat16) # t1932: "cuda:0 bf16[1, 32, 512, 128]" - t1933 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1933: "cuda:0 bf16[1, 32, 512, 128]" - t1934 = prims.slice_prim(t1933, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1934: "cuda:0 bf16[1, 32, 512, 64]" - t1935 = prims.slice_prim(t1933, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1935: "cuda:0 bf16[1, 32, 512, 64]" - t1936 = prims.convert_element_type(t1935, dtypes.float32) # t1936: "cuda:0 f32[1, 32, 512, 64]" - t1937 = prims.neg(t1936) # t1937: "cuda:0 f32[1, 32, 512, 64]" - t1938 = prims.convert_element_type(t1937, dtypes.bfloat16) # t1938: "cuda:0 bf16[1, 32, 512, 64]" - t1940 = prims.cat((t1938, t1934), -1) # t1940: "cuda:0 bf16[1, 32, 512, 128]" - t1941 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1941: "cuda:0 f32[1, 32, 512, 128]" - t1942 = prims.convert_element_type(t1933, dtypes.float32) # t1942: "cuda:0 f32[1, 32, 512, 128]" - t1943 = ltorch.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - # t1943 = prims.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - t1944 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1944: "cuda:0 f32[1, 32, 512, 128]" - t1945 = prims.convert_element_type(t1940, dtypes.float32) # t1945: "cuda:0 f32[1, 32, 512, 128]" - t1946 = ltorch.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - # t1946 = prims.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - t1947 = ltorch.add(t1943, t1946, alpha=None) # t1947: "cuda:0 f32[1, 32, 512, 128]" - # t1947 = prims.add(t1943, t1946) # t1947: "cuda:0 f32[1, 32, 512, 128]" - t1948 = prims.convert_element_type(t1947, dtypes.bfloat16) # t1948: "cuda:0 bf16[1, 32, 512, 128]" - t1949 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1949: "cuda:0 bf16[1, 32, 512, 0]" - t1951 = prims.cat((t1932, t1949), -1) # t1951: "cuda:0 bf16[1, 32, 512, 128]" - t1952 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1952: "cuda:0 bf16[1, 32, 512, 0]" - t1954 = prims.cat((t1948, t1952), -1) # t1954: "cuda:0 bf16[1, 32, 512, 128]" - (t1955, t1956, t1957, t1958) = cudnn_sdpa_fwd(t1951, t1954, t1916, None, 0.0, True, scale=0.08838834764831843) - t1961 = prims.transpose(t1955, (0, 2, 1, 3)) # t1961: "cuda:0 bf16[1, 512, 32, 128]" - t1965 = prims.reshape(t1961, (1, 512, 4096)) # t1965: "cuda:0 bf16[1, 512, 4096]" - t1966 = prims.linear(t1965, t_transformer_h_13_attn_proj_weight, None) # t1966: "cuda:0 bf16[1, 512, 4096]" - t1967 = prims.convert_element_type(t1966, dtypes.float32) # t1967: "cuda:0 f32[1, 512, 4096]" - t1968 = prims.convert_element_type(t1864, dtypes.float32) # t1968: "cuda:0 f32[1, 512, 4096]" - t1969 = ltorch.add(t1967, t1968, alpha=None) # t1969: "cuda:0 f32[1, 512, 4096]" - # t1969 = prims.add(t1967, t1968) # t1969: "cuda:0 f32[1, 512, 4096]" - t1970 = prims.convert_element_type(t1969, dtypes.bfloat16) # t1970: "cuda:0 bf16[1, 512, 4096]" - t1971 = prims.convert_element_type(t1970, dtypes.float32) # t1971: "cuda:0 f32[1, 512, 4096]" - t1972 = ltorch.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - # t1972 = prims.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - t1974 = prims.sum(t1972, (2,)) # t1974: "cuda:0 f32[1, 512]" - t1975 = prims.broadcast_in_dim(t1974, [1, 512, 1], [0, 1]) # t1975: "cuda:0 f32[1, 512, 1]" - t1977 = ltorch.true_divide(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - # t1977 = prims.div(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - t1979 = ltorch.add(t1977, 1e-05, alpha=None) # t1979: "cuda:0 f32[1, 512, 1]" - # t1979 = prims.add(t1977, 1e-05) # t1979: "cuda:0 f32[1, 512, 1]" - t1980 = prims.rsqrt(t1979) # t1980: "cuda:0 f32[1, 512, 1]" - t1981 = prims.broadcast_in_dim(t1980, (1, 512, 4096), (0, 1, 2)) # t1981: "cuda:0 f32[1, 512, 4096]" - t1982 = ltorch.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - # t1982 = prims.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - t1983 = prims.convert_element_type(t1982, dtypes.bfloat16) # t1983: "cuda:0 bf16[1, 512, 4096]" - t1984 = prims.broadcast_in_dim(t_transformer_h_13_norm_2_weight, (1, 512, 4096), (2,)) # t1984: "cuda:0 bf16[1, 512, 4096]" - t1985 = prims.convert_element_type(t1983, dtypes.float32) # t1985: "cuda:0 f32[1, 512, 4096]" - t1986 = prims.convert_element_type(t1984, dtypes.float32) # t1986: "cuda:0 f32[1, 512, 4096]" - t1987 = ltorch.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - # t1987 = prims.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - t1988 = prims.convert_element_type(t1987, dtypes.bfloat16) # t1988: "cuda:0 bf16[1, 512, 4096]" - t1989 = prims.linear(t1988, t_transformer_h_13_mlp_fc_1_weight, None) # t1989: "cuda:0 bf16[1, 512, 11008]" - t1990 = prims.linear(t1988, t_transformer_h_13_mlp_fc_2_weight, None) # t1990: "cuda:0 bf16[1, 512, 11008]" - t1991 = prims.convert_element_type(t1989, dtypes.float32) # t1991: "cuda:0 f32[1, 512, 11008]" - t1992 = prims.neg(t1991) # t1992: "cuda:0 f32[1, 512, 11008]" - t1993 = prims.exp(t1992) # t1993: "cuda:0 f32[1, 512, 11008]" - t1994 = ltorch.add(1.0, t1993, alpha=None) # t1994: "cuda:0 f32[1, 512, 11008]" - # t1994 = prims.add(1.0, t1993) # t1994: "cuda:0 f32[1, 512, 11008]" - t1995 = prims.reciprocal(t1994) # t1995: "cuda:0 f32[1, 512, 11008]" - t1996 = prims.convert_element_type(t1995, dtypes.bfloat16) # t1996: "cuda:0 bf16[1, 512, 11008]" - t1997 = prims.convert_element_type(t1989, dtypes.float32) # t1997: "cuda:0 f32[1, 512, 11008]" - t1998 = prims.convert_element_type(t1996, dtypes.float32) # t1998: "cuda:0 f32[1, 512, 11008]" - t1999 = ltorch.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - # t1999 = prims.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - t2000 = prims.convert_element_type(t1999, dtypes.bfloat16) # t2000: "cuda:0 bf16[1, 512, 11008]" - t2001 = prims.convert_element_type(t2000, dtypes.float32) # t2001: "cuda:0 f32[1, 512, 11008]" - t2002 = prims.convert_element_type(t1990, dtypes.float32) # t2002: "cuda:0 f32[1, 512, 11008]" - t2003 = ltorch.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - # t2003 = prims.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - t2004 = prims.convert_element_type(t2003, dtypes.bfloat16) # t2004: "cuda:0 bf16[1, 512, 11008]" - t2005 = prims.linear(t2004, t_transformer_h_13_mlp_proj_weight, None) # t2005: "cuda:0 bf16[1, 512, 4096]" - t2006 = prims.convert_element_type(t2005, dtypes.float32) # t2006: "cuda:0 f32[1, 512, 4096]" - t2007 = prims.convert_element_type(t1970, dtypes.float32) # t2007: "cuda:0 f32[1, 512, 4096]" - t2008 = ltorch.add(t2006, t2007, alpha=None) # t2008: "cuda:0 f32[1, 512, 4096]" - # t2008 = prims.add(t2006, t2007) # t2008: "cuda:0 f32[1, 512, 4096]" - t2009 = prims.convert_element_type(t2008, dtypes.bfloat16) # t2009: "cuda:0 bf16[1, 512, 4096]" - t2010 = prims.convert_element_type(t2009, dtypes.float32) # t2010: "cuda:0 f32[1, 512, 4096]" - t2011 = ltorch.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - # t2011 = prims.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - t2013 = prims.sum(t2011, (2,)) # t2013: "cuda:0 f32[1, 512]" - t2014 = prims.broadcast_in_dim(t2013, [1, 512, 1], [0, 1]) # t2014: "cuda:0 f32[1, 512, 1]" - t2016 = ltorch.true_divide(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - # t2016 = prims.div(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - t2018 = ltorch.add(t2016, 1e-05, alpha=None) # t2018: "cuda:0 f32[1, 512, 1]" - # t2018 = prims.add(t2016, 1e-05) # t2018: "cuda:0 f32[1, 512, 1]" - t2019 = prims.rsqrt(t2018) # t2019: "cuda:0 f32[1, 512, 1]" - t2020 = prims.broadcast_in_dim(t2019, (1, 512, 4096), (0, 1, 2)) # t2020: "cuda:0 f32[1, 512, 4096]" - t2021 = ltorch.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - # t2021 = prims.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - t2022 = prims.convert_element_type(t2021, dtypes.bfloat16) # t2022: "cuda:0 bf16[1, 512, 4096]" - t2023 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_weight, (1, 512, 4096), (2,)) # t2023: "cuda:0 bf16[1, 512, 4096]" - t2024 = prims.convert_element_type(t2022, dtypes.float32) # t2024: "cuda:0 f32[1, 512, 4096]" - t2025 = prims.convert_element_type(t2023, dtypes.float32) # t2025: "cuda:0 f32[1, 512, 4096]" - t2026 = ltorch.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - # t2026 = prims.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - t2027 = prims.convert_element_type(t2026, dtypes.bfloat16) # t2027: "cuda:0 bf16[1, 512, 4096]" - t2028 = prims.linear(t2027, t_transformer_h_14_attn_attn_weight, None) # t2028: "cuda:0 bf16[1, 512, 12288]" - t2034 = prims.reshape(t2028, (1, 512, 32, 3, 128)) # t2034: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2040 = prims.transpose(t2034, (0, 2, 3, 1, 4)) # t2040: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2041, t2042, t2043) = ltorch.split(t2040, (1, 1, 1), 2) - # t2041 = prims.slice_prim(t2040, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2041: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2042 = prims.slice_prim(t2040, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2042: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2043 = prims.slice_prim(t2040, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2043: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2049 = prims.reshape(t2041, (1, 32, 512, 128)) # t2049: "cuda:0 bf16[1, 32, 512, 128]" - t2055 = prims.reshape(t2042, (1, 32, 512, 128)) # t2055: "cuda:0 bf16[1, 32, 512, 128]" - t2061 = prims.reshape(t2043, (1, 32, 512, 128)) # t2061: "cuda:0 bf16[1, 32, 512, 128]" - t2062 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2062: "cuda:0 bf16[1, 32, 512, 128]" - t2063 = prims.slice_prim(t2062, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2063: "cuda:0 bf16[1, 32, 512, 64]" - t2064 = prims.slice_prim(t2062, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2064: "cuda:0 bf16[1, 32, 512, 64]" - t2065 = prims.convert_element_type(t2064, dtypes.float32) # t2065: "cuda:0 f32[1, 32, 512, 64]" - t2066 = prims.neg(t2065) # t2066: "cuda:0 f32[1, 32, 512, 64]" - t2067 = prims.convert_element_type(t2066, dtypes.bfloat16) # t2067: "cuda:0 bf16[1, 32, 512, 64]" - t2069 = prims.cat((t2067, t2063), -1) # t2069: "cuda:0 bf16[1, 32, 512, 128]" - t2070 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2070: "cuda:0 f32[1, 32, 512, 128]" - t2071 = prims.convert_element_type(t2062, dtypes.float32) # t2071: "cuda:0 f32[1, 32, 512, 128]" - t2072 = ltorch.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - # t2072 = prims.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - t2073 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2073: "cuda:0 f32[1, 32, 512, 128]" - t2074 = prims.convert_element_type(t2069, dtypes.float32) # t2074: "cuda:0 f32[1, 32, 512, 128]" - t2075 = ltorch.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - # t2075 = prims.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - t2076 = ltorch.add(t2072, t2075, alpha=None) # t2076: "cuda:0 f32[1, 32, 512, 128]" - # t2076 = prims.add(t2072, t2075) # t2076: "cuda:0 f32[1, 32, 512, 128]" - t2077 = prims.convert_element_type(t2076, dtypes.bfloat16) # t2077: "cuda:0 bf16[1, 32, 512, 128]" - t2078 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2078: "cuda:0 bf16[1, 32, 512, 128]" - t2079 = prims.slice_prim(t2078, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2079: "cuda:0 bf16[1, 32, 512, 64]" - t2080 = prims.slice_prim(t2078, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2080: "cuda:0 bf16[1, 32, 512, 64]" - t2081 = prims.convert_element_type(t2080, dtypes.float32) # t2081: "cuda:0 f32[1, 32, 512, 64]" - t2082 = prims.neg(t2081) # t2082: "cuda:0 f32[1, 32, 512, 64]" - t2083 = prims.convert_element_type(t2082, dtypes.bfloat16) # t2083: "cuda:0 bf16[1, 32, 512, 64]" - t2085 = prims.cat((t2083, t2079), -1) # t2085: "cuda:0 bf16[1, 32, 512, 128]" - t2086 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2086: "cuda:0 f32[1, 32, 512, 128]" - t2087 = prims.convert_element_type(t2078, dtypes.float32) # t2087: "cuda:0 f32[1, 32, 512, 128]" - t2088 = ltorch.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - # t2088 = prims.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - t2089 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2089: "cuda:0 f32[1, 32, 512, 128]" - t2090 = prims.convert_element_type(t2085, dtypes.float32) # t2090: "cuda:0 f32[1, 32, 512, 128]" - t2091 = ltorch.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - # t2091 = prims.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - t2092 = ltorch.add(t2088, t2091, alpha=None) # t2092: "cuda:0 f32[1, 32, 512, 128]" - # t2092 = prims.add(t2088, t2091) # t2092: "cuda:0 f32[1, 32, 512, 128]" - t2093 = prims.convert_element_type(t2092, dtypes.bfloat16) # t2093: "cuda:0 bf16[1, 32, 512, 128]" - t2094 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2094: "cuda:0 bf16[1, 32, 512, 0]" - t2096 = prims.cat((t2077, t2094), -1) # t2096: "cuda:0 bf16[1, 32, 512, 128]" - t2097 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2097: "cuda:0 bf16[1, 32, 512, 0]" - t2099 = prims.cat((t2093, t2097), -1) # t2099: "cuda:0 bf16[1, 32, 512, 128]" - (t2100, t2101, t2102, t2103) = cudnn_sdpa_fwd(t2096, t2099, t2061, None, 0.0, True, scale=0.08838834764831843) - t2106 = prims.transpose(t2100, (0, 2, 1, 3)) # t2106: "cuda:0 bf16[1, 512, 32, 128]" - t2110 = prims.reshape(t2106, (1, 512, 4096)) # t2110: "cuda:0 bf16[1, 512, 4096]" - t2111 = prims.linear(t2110, t_transformer_h_14_attn_proj_weight, None) # t2111: "cuda:0 bf16[1, 512, 4096]" - t2112 = prims.convert_element_type(t2111, dtypes.float32) # t2112: "cuda:0 f32[1, 512, 4096]" - t2113 = prims.convert_element_type(t2009, dtypes.float32) # t2113: "cuda:0 f32[1, 512, 4096]" - t2114 = ltorch.add(t2112, t2113, alpha=None) # t2114: "cuda:0 f32[1, 512, 4096]" - # t2114 = prims.add(t2112, t2113) # t2114: "cuda:0 f32[1, 512, 4096]" - t2115 = prims.convert_element_type(t2114, dtypes.bfloat16) # t2115: "cuda:0 bf16[1, 512, 4096]" - t2116 = prims.convert_element_type(t2115, dtypes.float32) # t2116: "cuda:0 f32[1, 512, 4096]" - t2117 = ltorch.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - # t2117 = prims.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - t2119 = prims.sum(t2117, (2,)) # t2119: "cuda:0 f32[1, 512]" - t2120 = prims.broadcast_in_dim(t2119, [1, 512, 1], [0, 1]) # t2120: "cuda:0 f32[1, 512, 1]" - t2122 = ltorch.true_divide(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - # t2122 = prims.div(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - t2124 = ltorch.add(t2122, 1e-05, alpha=None) # t2124: "cuda:0 f32[1, 512, 1]" - # t2124 = prims.add(t2122, 1e-05) # t2124: "cuda:0 f32[1, 512, 1]" - t2125 = prims.rsqrt(t2124) # t2125: "cuda:0 f32[1, 512, 1]" - t2126 = prims.broadcast_in_dim(t2125, (1, 512, 4096), (0, 1, 2)) # t2126: "cuda:0 f32[1, 512, 4096]" - t2127 = ltorch.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - # t2127 = prims.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - t2128 = prims.convert_element_type(t2127, dtypes.bfloat16) # t2128: "cuda:0 bf16[1, 512, 4096]" - t2129 = prims.broadcast_in_dim(t_transformer_h_14_norm_2_weight, (1, 512, 4096), (2,)) # t2129: "cuda:0 bf16[1, 512, 4096]" - t2130 = prims.convert_element_type(t2128, dtypes.float32) # t2130: "cuda:0 f32[1, 512, 4096]" - t2131 = prims.convert_element_type(t2129, dtypes.float32) # t2131: "cuda:0 f32[1, 512, 4096]" - t2132 = ltorch.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - # t2132 = prims.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - t2133 = prims.convert_element_type(t2132, dtypes.bfloat16) # t2133: "cuda:0 bf16[1, 512, 4096]" - t2134 = prims.linear(t2133, t_transformer_h_14_mlp_fc_1_weight, None) # t2134: "cuda:0 bf16[1, 512, 11008]" - t2135 = prims.linear(t2133, t_transformer_h_14_mlp_fc_2_weight, None) # t2135: "cuda:0 bf16[1, 512, 11008]" - t2136 = prims.convert_element_type(t2134, dtypes.float32) # t2136: "cuda:0 f32[1, 512, 11008]" - t2137 = prims.neg(t2136) # t2137: "cuda:0 f32[1, 512, 11008]" - t2138 = prims.exp(t2137) # t2138: "cuda:0 f32[1, 512, 11008]" - t2139 = ltorch.add(1.0, t2138, alpha=None) # t2139: "cuda:0 f32[1, 512, 11008]" - # t2139 = prims.add(1.0, t2138) # t2139: "cuda:0 f32[1, 512, 11008]" - t2140 = prims.reciprocal(t2139) # t2140: "cuda:0 f32[1, 512, 11008]" - t2141 = prims.convert_element_type(t2140, dtypes.bfloat16) # t2141: "cuda:0 bf16[1, 512, 11008]" - t2142 = prims.convert_element_type(t2134, dtypes.float32) # t2142: "cuda:0 f32[1, 512, 11008]" - t2143 = prims.convert_element_type(t2141, dtypes.float32) # t2143: "cuda:0 f32[1, 512, 11008]" - t2144 = ltorch.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - # t2144 = prims.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - t2145 = prims.convert_element_type(t2144, dtypes.bfloat16) # t2145: "cuda:0 bf16[1, 512, 11008]" - t2146 = prims.convert_element_type(t2145, dtypes.float32) # t2146: "cuda:0 f32[1, 512, 11008]" - t2147 = prims.convert_element_type(t2135, dtypes.float32) # t2147: "cuda:0 f32[1, 512, 11008]" - t2148 = ltorch.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - # t2148 = prims.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - t2149 = prims.convert_element_type(t2148, dtypes.bfloat16) # t2149: "cuda:0 bf16[1, 512, 11008]" - t2150 = prims.linear(t2149, t_transformer_h_14_mlp_proj_weight, None) # t2150: "cuda:0 bf16[1, 512, 4096]" - t2151 = prims.convert_element_type(t2150, dtypes.float32) # t2151: "cuda:0 f32[1, 512, 4096]" - t2152 = prims.convert_element_type(t2115, dtypes.float32) # t2152: "cuda:0 f32[1, 512, 4096]" - t2153 = ltorch.add(t2151, t2152, alpha=None) # t2153: "cuda:0 f32[1, 512, 4096]" - # t2153 = prims.add(t2151, t2152) # t2153: "cuda:0 f32[1, 512, 4096]" - t2154 = prims.convert_element_type(t2153, dtypes.bfloat16) # t2154: "cuda:0 bf16[1, 512, 4096]" - t2155 = prims.convert_element_type(t2154, dtypes.float32) # t2155: "cuda:0 f32[1, 512, 4096]" - t2156 = ltorch.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - # t2156 = prims.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - t2158 = prims.sum(t2156, (2,)) # t2158: "cuda:0 f32[1, 512]" - t2159 = prims.broadcast_in_dim(t2158, [1, 512, 1], [0, 1]) # t2159: "cuda:0 f32[1, 512, 1]" - t2161 = ltorch.true_divide(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - # t2161 = prims.div(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - t2163 = ltorch.add(t2161, 1e-05, alpha=None) # t2163: "cuda:0 f32[1, 512, 1]" - # t2163 = prims.add(t2161, 1e-05) # t2163: "cuda:0 f32[1, 512, 1]" - t2164 = prims.rsqrt(t2163) # t2164: "cuda:0 f32[1, 512, 1]" - t2165 = prims.broadcast_in_dim(t2164, (1, 512, 4096), (0, 1, 2)) # t2165: "cuda:0 f32[1, 512, 4096]" - t2166 = ltorch.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - # t2166 = prims.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - t2167 = prims.convert_element_type(t2166, dtypes.bfloat16) # t2167: "cuda:0 bf16[1, 512, 4096]" - t2168 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_weight, (1, 512, 4096), (2,)) # t2168: "cuda:0 bf16[1, 512, 4096]" - t2169 = prims.convert_element_type(t2167, dtypes.float32) # t2169: "cuda:0 f32[1, 512, 4096]" - t2170 = prims.convert_element_type(t2168, dtypes.float32) # t2170: "cuda:0 f32[1, 512, 4096]" - t2171 = ltorch.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - # t2171 = prims.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - t2172 = prims.convert_element_type(t2171, dtypes.bfloat16) # t2172: "cuda:0 bf16[1, 512, 4096]" - t2173 = prims.linear(t2172, t_transformer_h_15_attn_attn_weight, None) # t2173: "cuda:0 bf16[1, 512, 12288]" - t2179 = prims.reshape(t2173, (1, 512, 32, 3, 128)) # t2179: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2185 = prims.transpose(t2179, (0, 2, 3, 1, 4)) # t2185: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2186, t2187, t2188) = ltorch.split(t2185, (1, 1, 1), 2) - # t2186 = prims.slice_prim(t2185, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2186: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2187 = prims.slice_prim(t2185, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2187: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2188 = prims.slice_prim(t2185, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2188: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2194 = prims.reshape(t2186, (1, 32, 512, 128)) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - t2200 = prims.reshape(t2187, (1, 32, 512, 128)) # t2200: "cuda:0 bf16[1, 32, 512, 128]" - t2206 = prims.reshape(t2188, (1, 32, 512, 128)) # t2206: "cuda:0 bf16[1, 32, 512, 128]" - t2207 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2207: "cuda:0 bf16[1, 32, 512, 128]" - t2208 = prims.slice_prim(t2207, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2208: "cuda:0 bf16[1, 32, 512, 64]" - t2209 = prims.slice_prim(t2207, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2209: "cuda:0 bf16[1, 32, 512, 64]" - t2210 = prims.convert_element_type(t2209, dtypes.float32) # t2210: "cuda:0 f32[1, 32, 512, 64]" - t2211 = prims.neg(t2210) # t2211: "cuda:0 f32[1, 32, 512, 64]" - t2212 = prims.convert_element_type(t2211, dtypes.bfloat16) # t2212: "cuda:0 bf16[1, 32, 512, 64]" - t2214 = prims.cat((t2212, t2208), -1) # t2214: "cuda:0 bf16[1, 32, 512, 128]" - t2215 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2215: "cuda:0 f32[1, 32, 512, 128]" - t2216 = prims.convert_element_type(t2207, dtypes.float32) # t2216: "cuda:0 f32[1, 32, 512, 128]" - t2217 = ltorch.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - # t2217 = prims.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - t2218 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2218: "cuda:0 f32[1, 32, 512, 128]" - t2219 = prims.convert_element_type(t2214, dtypes.float32) # t2219: "cuda:0 f32[1, 32, 512, 128]" - t2220 = ltorch.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - # t2220 = prims.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - t2221 = ltorch.add(t2217, t2220, alpha=None) # t2221: "cuda:0 f32[1, 32, 512, 128]" - # t2221 = prims.add(t2217, t2220) # t2221: "cuda:0 f32[1, 32, 512, 128]" - t2222 = prims.convert_element_type(t2221, dtypes.bfloat16) # t2222: "cuda:0 bf16[1, 32, 512, 128]" - t2223 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2223: "cuda:0 bf16[1, 32, 512, 128]" - t2224 = prims.slice_prim(t2223, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2224: "cuda:0 bf16[1, 32, 512, 64]" - t2225 = prims.slice_prim(t2223, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2225: "cuda:0 bf16[1, 32, 512, 64]" - t2226 = prims.convert_element_type(t2225, dtypes.float32) # t2226: "cuda:0 f32[1, 32, 512, 64]" - t2227 = prims.neg(t2226) # t2227: "cuda:0 f32[1, 32, 512, 64]" - t2228 = prims.convert_element_type(t2227, dtypes.bfloat16) # t2228: "cuda:0 bf16[1, 32, 512, 64]" - t2230 = prims.cat((t2228, t2224), -1) # t2230: "cuda:0 bf16[1, 32, 512, 128]" - t2231 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2231: "cuda:0 f32[1, 32, 512, 128]" - t2232 = prims.convert_element_type(t2223, dtypes.float32) # t2232: "cuda:0 f32[1, 32, 512, 128]" - t2233 = ltorch.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - # t2233 = prims.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - t2234 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2234: "cuda:0 f32[1, 32, 512, 128]" - t2235 = prims.convert_element_type(t2230, dtypes.float32) # t2235: "cuda:0 f32[1, 32, 512, 128]" - t2236 = ltorch.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - # t2236 = prims.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - t2237 = ltorch.add(t2233, t2236, alpha=None) # t2237: "cuda:0 f32[1, 32, 512, 128]" - # t2237 = prims.add(t2233, t2236) # t2237: "cuda:0 f32[1, 32, 512, 128]" - t2238 = prims.convert_element_type(t2237, dtypes.bfloat16) # t2238: "cuda:0 bf16[1, 32, 512, 128]" - t2239 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2239: "cuda:0 bf16[1, 32, 512, 0]" - t2241 = prims.cat((t2222, t2239), -1) # t2241: "cuda:0 bf16[1, 32, 512, 128]" - t2242 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2242: "cuda:0 bf16[1, 32, 512, 0]" - t2244 = prims.cat((t2238, t2242), -1) # t2244: "cuda:0 bf16[1, 32, 512, 128]" - (t2245, t2246, t2247, t2248) = cudnn_sdpa_fwd(t2241, t2244, t2206, None, 0.0, True, scale=0.08838834764831843) - t2251 = prims.transpose(t2245, (0, 2, 1, 3)) # t2251: "cuda:0 bf16[1, 512, 32, 128]" - t2255 = prims.reshape(t2251, (1, 512, 4096)) # t2255: "cuda:0 bf16[1, 512, 4096]" - t2256 = prims.linear(t2255, t_transformer_h_15_attn_proj_weight, None) # t2256: "cuda:0 bf16[1, 512, 4096]" - t2257 = prims.convert_element_type(t2256, dtypes.float32) # t2257: "cuda:0 f32[1, 512, 4096]" - t2258 = prims.convert_element_type(t2154, dtypes.float32) # t2258: "cuda:0 f32[1, 512, 4096]" - t2259 = ltorch.add(t2257, t2258, alpha=None) # t2259: "cuda:0 f32[1, 512, 4096]" - # t2259 = prims.add(t2257, t2258) # t2259: "cuda:0 f32[1, 512, 4096]" - t2260 = prims.convert_element_type(t2259, dtypes.bfloat16) # t2260: "cuda:0 bf16[1, 512, 4096]" - t2261 = prims.convert_element_type(t2260, dtypes.float32) # t2261: "cuda:0 f32[1, 512, 4096]" - t2262 = ltorch.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - # t2262 = prims.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - t2264 = prims.sum(t2262, (2,)) # t2264: "cuda:0 f32[1, 512]" - t2265 = prims.broadcast_in_dim(t2264, [1, 512, 1], [0, 1]) # t2265: "cuda:0 f32[1, 512, 1]" - t2267 = ltorch.true_divide(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - # t2267 = prims.div(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - t2269 = ltorch.add(t2267, 1e-05, alpha=None) # t2269: "cuda:0 f32[1, 512, 1]" - # t2269 = prims.add(t2267, 1e-05) # t2269: "cuda:0 f32[1, 512, 1]" - t2270 = prims.rsqrt(t2269) # t2270: "cuda:0 f32[1, 512, 1]" - t2271 = prims.broadcast_in_dim(t2270, (1, 512, 4096), (0, 1, 2)) # t2271: "cuda:0 f32[1, 512, 4096]" - t2272 = ltorch.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - # t2272 = prims.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - t2273 = prims.convert_element_type(t2272, dtypes.bfloat16) # t2273: "cuda:0 bf16[1, 512, 4096]" - t2274 = prims.broadcast_in_dim(t_transformer_h_15_norm_2_weight, (1, 512, 4096), (2,)) # t2274: "cuda:0 bf16[1, 512, 4096]" - t2275 = prims.convert_element_type(t2273, dtypes.float32) # t2275: "cuda:0 f32[1, 512, 4096]" - t2276 = prims.convert_element_type(t2274, dtypes.float32) # t2276: "cuda:0 f32[1, 512, 4096]" - t2277 = ltorch.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - # t2277 = prims.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - t2278 = prims.convert_element_type(t2277, dtypes.bfloat16) # t2278: "cuda:0 bf16[1, 512, 4096]" - t2279 = prims.linear(t2278, t_transformer_h_15_mlp_fc_1_weight, None) # t2279: "cuda:0 bf16[1, 512, 11008]" - t2280 = prims.linear(t2278, t_transformer_h_15_mlp_fc_2_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - t2281 = prims.convert_element_type(t2279, dtypes.float32) # t2281: "cuda:0 f32[1, 512, 11008]" - t2282 = prims.neg(t2281) # t2282: "cuda:0 f32[1, 512, 11008]" - t2283 = prims.exp(t2282) # t2283: "cuda:0 f32[1, 512, 11008]" - t2284 = ltorch.add(1.0, t2283, alpha=None) # t2284: "cuda:0 f32[1, 512, 11008]" - # t2284 = prims.add(1.0, t2283) # t2284: "cuda:0 f32[1, 512, 11008]" - t2285 = prims.reciprocal(t2284) # t2285: "cuda:0 f32[1, 512, 11008]" - t2286 = prims.convert_element_type(t2285, dtypes.bfloat16) # t2286: "cuda:0 bf16[1, 512, 11008]" - t2287 = prims.convert_element_type(t2279, dtypes.float32) # t2287: "cuda:0 f32[1, 512, 11008]" - t2288 = prims.convert_element_type(t2286, dtypes.float32) # t2288: "cuda:0 f32[1, 512, 11008]" - t2289 = ltorch.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - # t2289 = prims.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - t2290 = prims.convert_element_type(t2289, dtypes.bfloat16) # t2290: "cuda:0 bf16[1, 512, 11008]" - t2291 = prims.convert_element_type(t2290, dtypes.float32) # t2291: "cuda:0 f32[1, 512, 11008]" - t2292 = prims.convert_element_type(t2280, dtypes.float32) # t2292: "cuda:0 f32[1, 512, 11008]" - t2293 = ltorch.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - # t2293 = prims.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - t2294 = prims.convert_element_type(t2293, dtypes.bfloat16) # t2294: "cuda:0 bf16[1, 512, 11008]" - t2295 = prims.linear(t2294, t_transformer_h_15_mlp_proj_weight, None) # t2295: "cuda:0 bf16[1, 512, 4096]" - t2296 = prims.convert_element_type(t2295, dtypes.float32) # t2296: "cuda:0 f32[1, 512, 4096]" - t2297 = prims.convert_element_type(t2260, dtypes.float32) # t2297: "cuda:0 f32[1, 512, 4096]" - t2298 = ltorch.add(t2296, t2297, alpha=None) # t2298: "cuda:0 f32[1, 512, 4096]" - # t2298 = prims.add(t2296, t2297) # t2298: "cuda:0 f32[1, 512, 4096]" - t2299 = prims.convert_element_type(t2298, dtypes.bfloat16) # t2299: "cuda:0 bf16[1, 512, 4096]" - t2300 = prims.convert_element_type(t2299, dtypes.float32) # t2300: "cuda:0 f32[1, 512, 4096]" - t2301 = ltorch.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - # t2301 = prims.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - t2303 = prims.sum(t2301, (2,)) # t2303: "cuda:0 f32[1, 512]" - t2304 = prims.broadcast_in_dim(t2303, [1, 512, 1], [0, 1]) # t2304: "cuda:0 f32[1, 512, 1]" - t2306 = ltorch.true_divide(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - # t2306 = prims.div(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - t2308 = ltorch.add(t2306, 1e-05, alpha=None) # t2308: "cuda:0 f32[1, 512, 1]" - # t2308 = prims.add(t2306, 1e-05) # t2308: "cuda:0 f32[1, 512, 1]" - t2309 = prims.rsqrt(t2308) # t2309: "cuda:0 f32[1, 512, 1]" - t2310 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t2310: "cuda:0 f32[1, 512, 4096]" - t2311 = ltorch.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - # t2311 = prims.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - t2312 = prims.convert_element_type(t2311, dtypes.bfloat16) # t2312: "cuda:0 bf16[1, 512, 4096]" - t2313 = prims.broadcast_in_dim(t_transformer_ln_f_weight, (1, 512, 4096), (2,)) # t2313: "cuda:0 bf16[1, 512, 4096]" - t2314 = prims.convert_element_type(t2312, dtypes.float32) # t2314: "cuda:0 f32[1, 512, 4096]" - t2315 = prims.convert_element_type(t2313, dtypes.float32) # t2315: "cuda:0 f32[1, 512, 4096]" - t2316 = ltorch.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - # t2316 = prims.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - t2317 = prims.convert_element_type(t2316, dtypes.bfloat16) # t2317: "cuda:0 bf16[1, 512, 4096]" - t2318 = prims.linear(t2317, t_lm_head_weight, None) # t2318: "cuda:0 bf16[1, 512, 32000]" - return {'output': t2318, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t2318,)}, ((idx, t5, t11, t12, t17, t16, t19, t_transformer_h_0_attn_attn_weight, t46, t47, t49, t50, t62, t63, t65, t66, t71, t74, t38, t75, t76, t77, t78, t80, t_transformer_h_0_attn_proj_weight, t86, t95, t96, t101, t100, t103, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t108, t110, t113, t112, t117, t116, t119, t_transformer_h_0_mlp_proj_weight, t125, t134, t135, t140, t139, t142, t_transformer_h_1_attn_attn_weight, t185, t186, t188, t189, t201, t202, t204, t205, t211, t214, t176, t215, t216, t217, t218, t225, t_transformer_h_1_attn_proj_weight, t231, t240, t241, t246, t245, t248, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t253, t255, t258, t257, t262, t261, t264, t_transformer_h_1_mlp_proj_weight, t270, t279, t280, t285, t284, t287, t_transformer_h_2_attn_attn_weight, t330, t331, t333, t334, t346, t347, t349, t350, t356, t359, t321, t360, t361, t362, t363, t370, t_transformer_h_2_attn_proj_weight, t376, t385, t386, t391, t390, t393, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t398, t400, t403, t402, t407, t406, t409, t_transformer_h_2_mlp_proj_weight, t415, t424, t425, t430, t429, t432, t_transformer_h_3_attn_attn_weight, t475, t476, t478, t479, t491, t492, t494, t495, t501, t504, t466, t505, t506, t507, t508, t515, t_transformer_h_3_attn_proj_weight, t521, t530, t531, t536, t535, t538, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t543, t545, t548, t547, t552, t551, t554, t_transformer_h_3_mlp_proj_weight, t560, t569, t570, t575, t574, t577, t_transformer_h_4_attn_attn_weight, t620, t621, t623, t624, t636, t637, t639, t640, t646, t649, t611, t650, t651, t652, t653, t660, t_transformer_h_4_attn_proj_weight, t666, t675, t676, t681, t680, t683, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t688, t690, t693, t692, t697, t696, t699, t_transformer_h_4_mlp_proj_weight, t705, t714, t715, t720, t719, t722, t_transformer_h_5_attn_attn_weight, t765, t766, t768, t769, t781, t782, t784, t785, t791, t794, t756, t795, t796, t797, t798, t805, t_transformer_h_5_attn_proj_weight, t811, t820, t821, t826, t825, t828, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t833, t835, t838, t837, t842, t841, t844, t_transformer_h_5_mlp_proj_weight, t850, t859, t860, t865, t864, t867, t_transformer_h_6_attn_attn_weight, t910, t911, t913, t914, t926, t927, t929, t930, t936, t939, t901, t940, t941, t942, t943, t950, t_transformer_h_6_attn_proj_weight, t956, t965, t966, t971, t970, t973, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t978, t980, t983, t982, t987, t986, t989, t_transformer_h_6_mlp_proj_weight, t995, t1004, t1005, t1010, t1009, t1012, t_transformer_h_7_attn_attn_weight, t1055, t1056, t1058, t1059, t1071, t1072, t1074, t1075, t1081, t1084, t1046, t1085, t1086, t1087, t1088, t1095, t_transformer_h_7_attn_proj_weight, t1101, t1110, t1111, t1116, t1115, t1118, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t1123, t1125, t1128, t1127, t1132, t1131, t1134, t_transformer_h_7_mlp_proj_weight, t1140, t1149, t1150, t1155, t1154, t1157, t_transformer_h_8_attn_attn_weight, t1200, t1201, t1203, t1204, t1216, t1217, t1219, t1220, t1226, t1229, t1191, t1230, t1231, t1232, t1233, t1240, t_transformer_h_8_attn_proj_weight, t1246, t1255, t1256, t1261, t1260, t1263, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t1268, t1270, t1273, t1272, t1277, t1276, t1279, t_transformer_h_8_mlp_proj_weight, t1285, t1294, t1295, t1300, t1299, t1302, t_transformer_h_9_attn_attn_weight, t1345, t1346, t1348, t1349, t1361, t1362, t1364, t1365, t1371, t1374, t1336, t1375, t1376, t1377, t1378, t1385, t_transformer_h_9_attn_proj_weight, t1391, t1400, t1401, t1406, t1405, t1408, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t1413, t1415, t1418, t1417, t1422, t1421, t1424, t_transformer_h_9_mlp_proj_weight, t1430, t1439, t1440, t1445, t1444, t1447, t_transformer_h_10_attn_attn_weight, t1490, t1491, t1493, t1494, t1506, t1507, t1509, t1510, t1516, t1519, t1481, t1520, t1521, t1522, t1523, t1530, t_transformer_h_10_attn_proj_weight, t1536, t1545, t1546, t1551, t1550, t1553, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t1558, t1560, t1563, t1562, t1567, t1566, t1569, t_transformer_h_10_mlp_proj_weight, t1575, t1584, t1585, t1590, t1589, t1592, t_transformer_h_11_attn_attn_weight, t1635, t1636, t1638, t1639, t1651, t1652, t1654, t1655, t1661, t1664, t1626, t1665, t1666, t1667, t1668, t1675, t_transformer_h_11_attn_proj_weight, t1681, t1690, t1691, t1696, t1695, t1698, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t1703, t1705, t1708, t1707, t1712, t1711, t1714, t_transformer_h_11_mlp_proj_weight, t1720, t1729, t1730, t1735, t1734, t1737, t_transformer_h_12_attn_attn_weight, t1780, t1781, t1783, t1784, t1796, t1797, t1799, t1800, t1806, t1809, t1771, t1810, t1811, t1812, t1813, t1820, t_transformer_h_12_attn_proj_weight, t1826, t1835, t1836, t1841, t1840, t1843, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t1848, t1850, t1853, t1852, t1857, t1856, t1859, t_transformer_h_12_mlp_proj_weight, t1865, t1874, t1875, t1880, t1879, t1882, t_transformer_h_13_attn_attn_weight, t1925, t1926, t1928, t1929, t1941, t1942, t1944, t1945, t1951, t1954, t1916, t1955, t1956, t1957, t1958, t1965, t_transformer_h_13_attn_proj_weight, t1971, t1980, t1981, t1986, t1985, t1988, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t1993, t1995, t1998, t1997, t2002, t2001, t2004, t_transformer_h_13_mlp_proj_weight, t2010, t2019, t2020, t2025, t2024, t2027, t_transformer_h_14_attn_attn_weight, t2070, t2071, t2073, t2074, t2086, t2087, t2089, t2090, t2096, t2099, t2061, t2100, t2101, t2102, t2103, t2110, t_transformer_h_14_attn_proj_weight, t2116, t2125, t2126, t2131, t2130, t2133, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t2138, t2140, t2143, t2142, t2147, t2146, t2149, t_transformer_h_14_mlp_proj_weight, t2155, t2164, t2165, t2170, t2169, t2172, t_transformer_h_15_attn_attn_weight, t2215, t2216, t2218, t2219, t2231, t2232, t2234, t2235, t2241, t2244, t2206, t2245, t2246, t2247, t2248, t2255, t_transformer_h_15_attn_proj_weight, t2261, t2270, t2271, t2276, t2275, t2278, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t2283, t2285, t2288, t2287, t2292, t2291, t2294, t_transformer_h_15_mlp_proj_weight, t2300, t2309, t2310, t2315, t2314, t2317, t_lm_head_weight), (32000, False, False, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0)) -============================================ END: primal_trace forward_and_backward_from_trace -============================================ START: before forward_trc transform_for_execution -# Constructed by Augmented forward pass -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight): - # idx: "cuda:0 i64[1, 512]" - # tos1: "cuda:0 f32[4096, 128]" - # t_lm_head_weight: "cuda:0 bf16[32000, 4096]" - # t_sin: "cuda:0 f32[4096, 128]" - # t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_0_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_0_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_1_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_1_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_1_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_2_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_2_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_2_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_3_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_3_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_3_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_4_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_4_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_4_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_5_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_5_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_5_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_6_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_6_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_6_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_7_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_7_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_7_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_8_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_8_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_8_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_9_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_9_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_9_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_10_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_10_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_10_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_11_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_11_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_11_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_12_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_12_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_12_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_13_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_13_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_13_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_14_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_14_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_14_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_15_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_15_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_15_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_ln_f_weight: "cuda:0 bf16[4096]" - # t_transformer_wte_weight: "cuda:0 bf16[32000, 4096]" - t0 = prims.slice_prim(tos1, [0, 0], [512, 128], [1, 1]) # t0: "cuda:0 f32[512, 128]" - t1 = prims.slice_prim(t_sin, [0, 0], [512, 128], [1, 1]) # t1: "cuda:0 f32[512, 128]" - t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 512, 4096]" - # t2 = ltorch.reshape(idx, [512]) # t2: "cuda:0 i64[512]" - # t2 = prims.reshape(idx, (512,)) # t2: "cuda:0 i64[512]" - # t3 = prims.take(t_transformer_wte_weight, t2, 0) # t3: "cuda:0 bf16[512, 4096]" - # t4 = ltorch.reshape(t3, [1, 512, 4096]) # t4: "cuda:0 bf16[1, 512, 4096]" - # t4 = prims.reshape(t3, (1, 512, 4096)) # t4: "cuda:0 bf16[1, 512, 4096]" - t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 512, 4096]" - t6 = ltorch.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - # t6 = prims.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - t7 = prims.sum(t6, (2,)) # t7: "cuda:0 f32[1, 512]" - t8 = prims.broadcast_in_dim(t7, [1, 512, 1], [0, 1]) # t8: "cuda:0 f32[1, 512, 1]" - t9 = ltorch.true_divide(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - # t9 = prims.div(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - t10 = ltorch.add(t9, 1e-05, alpha=None) # t10: "cuda:0 f32[1, 512, 1]" - # t10 = prims.add(t9, 1e-05) # t10: "cuda:0 f32[1, 512, 1]" - t11 = prims.rsqrt(t10) # t11: "cuda:0 f32[1, 512, 1]" - t12 = prims.broadcast_in_dim(t11, (1, 512, 4096), (0, 1, 2)) # t12: "cuda:0 f32[1, 512, 4096]" - t13 = ltorch.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - # t13 = prims.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - t14 = prims.convert_element_type(t13, dtypes.bfloat16) # t14: "cuda:0 bf16[1, 512, 4096]" - t15 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, (1, 512, 4096), (2,)) # t15: "cuda:0 bf16[1, 512, 4096]" - t16 = prims.convert_element_type(t14, dtypes.float32) # t16: "cuda:0 f32[1, 512, 4096]" - t17 = prims.convert_element_type(t15, dtypes.float32) # t17: "cuda:0 f32[1, 512, 4096]" - t18 = ltorch.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - # t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - t19 = prims.convert_element_type(t18, dtypes.bfloat16) # t19: "cuda:0 bf16[1, 512, 4096]" - t20 = prims.linear(t19, t_transformer_h_0_attn_attn_weight, None) # t20: "cuda:0 bf16[1, 512, 12288]" - t21 = prims.reshape(t20, (1, 512, 32, 3, 128)) # t21: "cuda:0 bf16[1, 512, 32, 3, 128]" - t22 = prims.transpose(t21, (0, 2, 3, 1, 4)) # t22: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t23, t24, t25) = ltorch.split(t22, (1, 1, 1), 2) - # t23 = prims.slice_prim(t22, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t23: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t24 = prims.slice_prim(t22, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t24: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t25 = prims.slice_prim(t22, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t25: "cuda:0 bf16[1, 32, 1, 512, 128]" - t26 = prims.reshape(t23, (1, 32, 512, 128)) # t26: "cuda:0 bf16[1, 32, 512, 128]" - t32 = prims.reshape(t24, (1, 32, 512, 128)) # t32: "cuda:0 bf16[1, 32, 512, 128]" - t38 = prims.reshape(t25, (1, 32, 512, 128)) # t38: "cuda:0 bf16[1, 32, 512, 128]" - t39 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t39: "cuda:0 bf16[1, 32, 512, 128]" - t40 = prims.slice_prim(t39, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t40: "cuda:0 bf16[1, 32, 512, 64]" - t41 = prims.slice_prim(t39, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t41: "cuda:0 bf16[1, 32, 512, 64]" - t42 = prims.convert_element_type(t41, dtypes.float32) # t42: "cuda:0 f32[1, 32, 512, 64]" - t43 = prims.neg(t42) # t43: "cuda:0 f32[1, 32, 512, 64]" - t44 = prims.convert_element_type(t43, dtypes.bfloat16) # t44: "cuda:0 bf16[1, 32, 512, 64]" - t45 = prims.cat((t44, t40), -1) # t45: "cuda:0 bf16[1, 32, 512, 128]" - t46 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t46: "cuda:0 f32[1, 32, 512, 128]" - t47 = prims.convert_element_type(t39, dtypes.float32) # t47: "cuda:0 f32[1, 32, 512, 128]" - t48 = ltorch.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - # t48 = prims.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - t49 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t49: "cuda:0 f32[1, 32, 512, 128]" - t50 = prims.convert_element_type(t45, dtypes.float32) # t50: "cuda:0 f32[1, 32, 512, 128]" - t51 = ltorch.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - # t51 = prims.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - t52 = ltorch.add(t48, t51, alpha=None) # t52: "cuda:0 f32[1, 32, 512, 128]" - # t52 = prims.add(t48, t51) # t52: "cuda:0 f32[1, 32, 512, 128]" - t53 = prims.convert_element_type(t52, dtypes.bfloat16) # t53: "cuda:0 bf16[1, 32, 512, 128]" - t54 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 32, 512, 128]" - t55 = prims.slice_prim(t54, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t55: "cuda:0 bf16[1, 32, 512, 64]" - t56 = prims.slice_prim(t54, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t56: "cuda:0 bf16[1, 32, 512, 64]" - t57 = prims.convert_element_type(t56, dtypes.float32) # t57: "cuda:0 f32[1, 32, 512, 64]" - t58 = prims.neg(t57) # t58: "cuda:0 f32[1, 32, 512, 64]" - t59 = prims.convert_element_type(t58, dtypes.bfloat16) # t59: "cuda:0 bf16[1, 32, 512, 64]" - t61 = prims.cat((t59, t55), -1) # t61: "cuda:0 bf16[1, 32, 512, 128]" - t62 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t62: "cuda:0 f32[1, 32, 512, 128]" - t63 = prims.convert_element_type(t54, dtypes.float32) # t63: "cuda:0 f32[1, 32, 512, 128]" - t64 = ltorch.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - # t64 = prims.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - t65 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t65: "cuda:0 f32[1, 32, 512, 128]" - t66 = prims.convert_element_type(t61, dtypes.float32) # t66: "cuda:0 f32[1, 32, 512, 128]" - t67 = ltorch.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - # t67 = prims.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - t68 = ltorch.add(t64, t67, alpha=None) # t68: "cuda:0 f32[1, 32, 512, 128]" - # t68 = prims.add(t64, t67) # t68: "cuda:0 f32[1, 32, 512, 128]" - t69 = prims.convert_element_type(t68, dtypes.bfloat16) # t69: "cuda:0 bf16[1, 32, 512, 128]" - t70 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t70: "cuda:0 bf16[1, 32, 512, 0]" - t71 = prims.cat((t53, t70), -1) # t71: "cuda:0 bf16[1, 32, 512, 128]" - t72 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t72: "cuda:0 bf16[1, 32, 512, 0]" - t74 = prims.cat((t69, t72), -1) # t74: "cuda:0 bf16[1, 32, 512, 128]" - (t75, t76, t77, t78) = cudnn_sdpa_fwd(t71, t74, t38, None, 0.0, True, scale=0.08838834764831843) - t79 = prims.transpose(t75, (0, 2, 1, 3)) # t79: "cuda:0 bf16[1, 512, 32, 128]" - t80 = prims.reshape(t79, (1, 512, 4096)) # t80: "cuda:0 bf16[1, 512, 4096]" - t81 = prims.linear(t80, t_transformer_h_0_attn_proj_weight, None) # t81: "cuda:0 bf16[1, 512, 4096]" - t82 = prims.convert_element_type(t81, dtypes.float32) # t82: "cuda:0 f32[1, 512, 4096]" - t83 = prims.convert_element_type(t4, dtypes.float32) # t83: "cuda:0 f32[1, 512, 4096]" - t84 = ltorch.add(t82, t83, alpha=None) # t84: "cuda:0 f32[1, 512, 4096]" - # t84 = prims.add(t82, t83) # t84: "cuda:0 f32[1, 512, 4096]" - t85 = prims.convert_element_type(t84, dtypes.bfloat16) # t85: "cuda:0 bf16[1, 512, 4096]" - t86 = prims.convert_element_type(t85, dtypes.float32) # t86: "cuda:0 f32[1, 512, 4096]" - t87 = ltorch.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - # t87 = prims.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - t89 = prims.sum(t87, (2,)) # t89: "cuda:0 f32[1, 512]" - t90 = prims.broadcast_in_dim(t89, [1, 512, 1], [0, 1]) # t90: "cuda:0 f32[1, 512, 1]" - t92 = ltorch.true_divide(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - # t92 = prims.div(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - t94 = ltorch.add(t92, 1e-05, alpha=None) # t94: "cuda:0 f32[1, 512, 1]" - # t94 = prims.add(t92, 1e-05) # t94: "cuda:0 f32[1, 512, 1]" - t95 = prims.rsqrt(t94) # t95: "cuda:0 f32[1, 512, 1]" - t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: "cuda:0 f32[1, 512, 4096]" - t97 = ltorch.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - # t97 = prims.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - t98 = prims.convert_element_type(t97, dtypes.bfloat16) # t98: "cuda:0 bf16[1, 512, 4096]" - t99 = prims.broadcast_in_dim(t_transformer_h_0_norm_2_weight, (1, 512, 4096), (2,)) # t99: "cuda:0 bf16[1, 512, 4096]" - t100 = prims.convert_element_type(t98, dtypes.float32) # t100: "cuda:0 f32[1, 512, 4096]" - t101 = prims.convert_element_type(t99, dtypes.float32) # t101: "cuda:0 f32[1, 512, 4096]" - t102 = ltorch.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - # t102 = prims.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - t103 = prims.convert_element_type(t102, dtypes.bfloat16) # t103: "cuda:0 bf16[1, 512, 4096]" - t104 = prims.linear(t103, t_transformer_h_0_mlp_fc_1_weight, None) # t104: "cuda:0 bf16[1, 512, 11008]" - t105 = prims.linear(t103, t_transformer_h_0_mlp_fc_2_weight, None) # t105: "cuda:0 bf16[1, 512, 11008]" - t106 = prims.convert_element_type(t104, dtypes.float32) # t106: "cuda:0 f32[1, 512, 11008]" - t107 = prims.neg(t106) # t107: "cuda:0 f32[1, 512, 11008]" - t108 = prims.exp(t107) # t108: "cuda:0 f32[1, 512, 11008]" - t109 = ltorch.add(1.0, t108, alpha=None) # t109: "cuda:0 f32[1, 512, 11008]" - # t109 = prims.add(1.0, t108) # t109: "cuda:0 f32[1, 512, 11008]" - t110 = prims.reciprocal(t109) # t110: "cuda:0 f32[1, 512, 11008]" - t111 = prims.convert_element_type(t110, dtypes.bfloat16) # t111: "cuda:0 bf16[1, 512, 11008]" - t112 = prims.convert_element_type(t104, dtypes.float32) # t112: "cuda:0 f32[1, 512, 11008]" - t113 = prims.convert_element_type(t111, dtypes.float32) # t113: "cuda:0 f32[1, 512, 11008]" - t114 = ltorch.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - # t114 = prims.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - t115 = prims.convert_element_type(t114, dtypes.bfloat16) # t115: "cuda:0 bf16[1, 512, 11008]" - t116 = prims.convert_element_type(t115, dtypes.float32) # t116: "cuda:0 f32[1, 512, 11008]" - t117 = prims.convert_element_type(t105, dtypes.float32) # t117: "cuda:0 f32[1, 512, 11008]" - t118 = ltorch.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - # t118 = prims.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - t119 = prims.convert_element_type(t118, dtypes.bfloat16) # t119: "cuda:0 bf16[1, 512, 11008]" - t120 = prims.linear(t119, t_transformer_h_0_mlp_proj_weight, None) # t120: "cuda:0 bf16[1, 512, 4096]" - t121 = prims.convert_element_type(t120, dtypes.float32) # t121: "cuda:0 f32[1, 512, 4096]" - t122 = prims.convert_element_type(t85, dtypes.float32) # t122: "cuda:0 f32[1, 512, 4096]" - t123 = ltorch.add(t121, t122, alpha=None) # t123: "cuda:0 f32[1, 512, 4096]" - # t123 = prims.add(t121, t122) # t123: "cuda:0 f32[1, 512, 4096]" - t124 = prims.convert_element_type(t123, dtypes.bfloat16) # t124: "cuda:0 bf16[1, 512, 4096]" - t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 512, 4096]" - t126 = ltorch.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - # t126 = prims.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - t128 = prims.sum(t126, (2,)) # t128: "cuda:0 f32[1, 512]" - t129 = prims.broadcast_in_dim(t128, [1, 512, 1], [0, 1]) # t129: "cuda:0 f32[1, 512, 1]" - t131 = ltorch.true_divide(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - # t131 = prims.div(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - t133 = ltorch.add(t131, 1e-05, alpha=None) # t133: "cuda:0 f32[1, 512, 1]" - # t133 = prims.add(t131, 1e-05) # t133: "cuda:0 f32[1, 512, 1]" - t134 = prims.rsqrt(t133) # t134: "cuda:0 f32[1, 512, 1]" - t135 = prims.broadcast_in_dim(t134, (1, 512, 4096), (0, 1, 2)) # t135: "cuda:0 f32[1, 512, 4096]" - t136 = ltorch.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - # t136 = prims.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: "cuda:0 bf16[1, 512, 4096]" - t138 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_weight, (1, 512, 4096), (2,)) # t138: "cuda:0 bf16[1, 512, 4096]" - t139 = prims.convert_element_type(t137, dtypes.float32) # t139: "cuda:0 f32[1, 512, 4096]" - t140 = prims.convert_element_type(t138, dtypes.float32) # t140: "cuda:0 f32[1, 512, 4096]" - t141 = ltorch.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - # t141 = prims.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - t142 = prims.convert_element_type(t141, dtypes.bfloat16) # t142: "cuda:0 bf16[1, 512, 4096]" - t143 = prims.linear(t142, t_transformer_h_1_attn_attn_weight, None) # t143: "cuda:0 bf16[1, 512, 12288]" - t149 = prims.reshape(t143, (1, 512, 32, 3, 128)) # t149: "cuda:0 bf16[1, 512, 32, 3, 128]" - t155 = prims.transpose(t149, (0, 2, 3, 1, 4)) # t155: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t156, t157, t158) = ltorch.split(t155, (1, 1, 1), 2) - # t156 = prims.slice_prim(t155, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t156: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t157 = prims.slice_prim(t155, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t157: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t158 = prims.slice_prim(t155, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t158: "cuda:0 bf16[1, 32, 1, 512, 128]" - t164 = prims.reshape(t156, (1, 32, 512, 128)) # t164: "cuda:0 bf16[1, 32, 512, 128]" - t170 = prims.reshape(t157, (1, 32, 512, 128)) # t170: "cuda:0 bf16[1, 32, 512, 128]" - t176 = prims.reshape(t158, (1, 32, 512, 128)) # t176: "cuda:0 bf16[1, 32, 512, 128]" - t177 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t177: "cuda:0 bf16[1, 32, 512, 128]" - t178 = prims.slice_prim(t177, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t178: "cuda:0 bf16[1, 32, 512, 64]" - t179 = prims.slice_prim(t177, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t179: "cuda:0 bf16[1, 32, 512, 64]" - t180 = prims.convert_element_type(t179, dtypes.float32) # t180: "cuda:0 f32[1, 32, 512, 64]" - t181 = prims.neg(t180) # t181: "cuda:0 f32[1, 32, 512, 64]" - t182 = prims.convert_element_type(t181, dtypes.bfloat16) # t182: "cuda:0 bf16[1, 32, 512, 64]" - t184 = prims.cat((t182, t178), -1) # t184: "cuda:0 bf16[1, 32, 512, 128]" - t185 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t185: "cuda:0 f32[1, 32, 512, 128]" - t186 = prims.convert_element_type(t177, dtypes.float32) # t186: "cuda:0 f32[1, 32, 512, 128]" - t187 = ltorch.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - # t187 = prims.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - t188 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t188: "cuda:0 f32[1, 32, 512, 128]" - t189 = prims.convert_element_type(t184, dtypes.float32) # t189: "cuda:0 f32[1, 32, 512, 128]" - t190 = ltorch.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - # t190 = prims.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - t191 = ltorch.add(t187, t190, alpha=None) # t191: "cuda:0 f32[1, 32, 512, 128]" - # t191 = prims.add(t187, t190) # t191: "cuda:0 f32[1, 32, 512, 128]" - t192 = prims.convert_element_type(t191, dtypes.bfloat16) # t192: "cuda:0 bf16[1, 32, 512, 128]" - t193 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t193: "cuda:0 bf16[1, 32, 512, 128]" - t194 = prims.slice_prim(t193, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t194: "cuda:0 bf16[1, 32, 512, 64]" - t195 = prims.slice_prim(t193, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t195: "cuda:0 bf16[1, 32, 512, 64]" - t196 = prims.convert_element_type(t195, dtypes.float32) # t196: "cuda:0 f32[1, 32, 512, 64]" - t197 = prims.neg(t196) # t197: "cuda:0 f32[1, 32, 512, 64]" - t198 = prims.convert_element_type(t197, dtypes.bfloat16) # t198: "cuda:0 bf16[1, 32, 512, 64]" - t200 = prims.cat((t198, t194), -1) # t200: "cuda:0 bf16[1, 32, 512, 128]" - t201 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t201: "cuda:0 f32[1, 32, 512, 128]" - t202 = prims.convert_element_type(t193, dtypes.float32) # t202: "cuda:0 f32[1, 32, 512, 128]" - t203 = ltorch.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - # t203 = prims.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - t204 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t204: "cuda:0 f32[1, 32, 512, 128]" - t205 = prims.convert_element_type(t200, dtypes.float32) # t205: "cuda:0 f32[1, 32, 512, 128]" - t206 = ltorch.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - # t206 = prims.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - t207 = ltorch.add(t203, t206, alpha=None) # t207: "cuda:0 f32[1, 32, 512, 128]" - # t207 = prims.add(t203, t206) # t207: "cuda:0 f32[1, 32, 512, 128]" - t208 = prims.convert_element_type(t207, dtypes.bfloat16) # t208: "cuda:0 bf16[1, 32, 512, 128]" - t209 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t209: "cuda:0 bf16[1, 32, 512, 0]" - t211 = prims.cat((t192, t209), -1) # t211: "cuda:0 bf16[1, 32, 512, 128]" - t212 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t212: "cuda:0 bf16[1, 32, 512, 0]" - t214 = prims.cat((t208, t212), -1) # t214: "cuda:0 bf16[1, 32, 512, 128]" - (t215, t216, t217, t218) = cudnn_sdpa_fwd(t211, t214, t176, None, 0.0, True, scale=0.08838834764831843) - t221 = prims.transpose(t215, (0, 2, 1, 3)) # t221: "cuda:0 bf16[1, 512, 32, 128]" - t225 = prims.reshape(t221, (1, 512, 4096)) # t225: "cuda:0 bf16[1, 512, 4096]" - t226 = prims.linear(t225, t_transformer_h_1_attn_proj_weight, None) # t226: "cuda:0 bf16[1, 512, 4096]" - t227 = prims.convert_element_type(t226, dtypes.float32) # t227: "cuda:0 f32[1, 512, 4096]" - t228 = prims.convert_element_type(t124, dtypes.float32) # t228: "cuda:0 f32[1, 512, 4096]" - t229 = ltorch.add(t227, t228, alpha=None) # t229: "cuda:0 f32[1, 512, 4096]" - # t229 = prims.add(t227, t228) # t229: "cuda:0 f32[1, 512, 4096]" - t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: "cuda:0 bf16[1, 512, 4096]" - t231 = prims.convert_element_type(t230, dtypes.float32) # t231: "cuda:0 f32[1, 512, 4096]" - t232 = ltorch.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - # t232 = prims.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - t234 = prims.sum(t232, (2,)) # t234: "cuda:0 f32[1, 512]" - t235 = prims.broadcast_in_dim(t234, [1, 512, 1], [0, 1]) # t235: "cuda:0 f32[1, 512, 1]" - t237 = ltorch.true_divide(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - # t237 = prims.div(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - t239 = ltorch.add(t237, 1e-05, alpha=None) # t239: "cuda:0 f32[1, 512, 1]" - # t239 = prims.add(t237, 1e-05) # t239: "cuda:0 f32[1, 512, 1]" - t240 = prims.rsqrt(t239) # t240: "cuda:0 f32[1, 512, 1]" - t241 = prims.broadcast_in_dim(t240, (1, 512, 4096), (0, 1, 2)) # t241: "cuda:0 f32[1, 512, 4096]" - t242 = ltorch.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - # t242 = prims.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - t243 = prims.convert_element_type(t242, dtypes.bfloat16) # t243: "cuda:0 bf16[1, 512, 4096]" - t244 = prims.broadcast_in_dim(t_transformer_h_1_norm_2_weight, (1, 512, 4096), (2,)) # t244: "cuda:0 bf16[1, 512, 4096]" - t245 = prims.convert_element_type(t243, dtypes.float32) # t245: "cuda:0 f32[1, 512, 4096]" - t246 = prims.convert_element_type(t244, dtypes.float32) # t246: "cuda:0 f32[1, 512, 4096]" - t247 = ltorch.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - # t247 = prims.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - t248 = prims.convert_element_type(t247, dtypes.bfloat16) # t248: "cuda:0 bf16[1, 512, 4096]" - t249 = prims.linear(t248, t_transformer_h_1_mlp_fc_1_weight, None) # t249: "cuda:0 bf16[1, 512, 11008]" - t250 = prims.linear(t248, t_transformer_h_1_mlp_fc_2_weight, None) # t250: "cuda:0 bf16[1, 512, 11008]" - t251 = prims.convert_element_type(t249, dtypes.float32) # t251: "cuda:0 f32[1, 512, 11008]" - t252 = prims.neg(t251) # t252: "cuda:0 f32[1, 512, 11008]" - t253 = prims.exp(t252) # t253: "cuda:0 f32[1, 512, 11008]" - t254 = ltorch.add(1.0, t253, alpha=None) # t254: "cuda:0 f32[1, 512, 11008]" - # t254 = prims.add(1.0, t253) # t254: "cuda:0 f32[1, 512, 11008]" - t255 = prims.reciprocal(t254) # t255: "cuda:0 f32[1, 512, 11008]" - t256 = prims.convert_element_type(t255, dtypes.bfloat16) # t256: "cuda:0 bf16[1, 512, 11008]" - t257 = prims.convert_element_type(t249, dtypes.float32) # t257: "cuda:0 f32[1, 512, 11008]" - t258 = prims.convert_element_type(t256, dtypes.float32) # t258: "cuda:0 f32[1, 512, 11008]" - t259 = ltorch.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - # t259 = prims.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 512, 11008]" - t261 = prims.convert_element_type(t260, dtypes.float32) # t261: "cuda:0 f32[1, 512, 11008]" - t262 = prims.convert_element_type(t250, dtypes.float32) # t262: "cuda:0 f32[1, 512, 11008]" - t263 = ltorch.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - # t263 = prims.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - t264 = prims.convert_element_type(t263, dtypes.bfloat16) # t264: "cuda:0 bf16[1, 512, 11008]" - t265 = prims.linear(t264, t_transformer_h_1_mlp_proj_weight, None) # t265: "cuda:0 bf16[1, 512, 4096]" - t266 = prims.convert_element_type(t265, dtypes.float32) # t266: "cuda:0 f32[1, 512, 4096]" - t267 = prims.convert_element_type(t230, dtypes.float32) # t267: "cuda:0 f32[1, 512, 4096]" - t268 = ltorch.add(t266, t267, alpha=None) # t268: "cuda:0 f32[1, 512, 4096]" - # t268 = prims.add(t266, t267) # t268: "cuda:0 f32[1, 512, 4096]" - t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: "cuda:0 bf16[1, 512, 4096]" - t270 = prims.convert_element_type(t269, dtypes.float32) # t270: "cuda:0 f32[1, 512, 4096]" - t271 = ltorch.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - # t271 = prims.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - t273 = prims.sum(t271, (2,)) # t273: "cuda:0 f32[1, 512]" - t274 = prims.broadcast_in_dim(t273, [1, 512, 1], [0, 1]) # t274: "cuda:0 f32[1, 512, 1]" - t276 = ltorch.true_divide(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - # t276 = prims.div(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - t278 = ltorch.add(t276, 1e-05, alpha=None) # t278: "cuda:0 f32[1, 512, 1]" - # t278 = prims.add(t276, 1e-05) # t278: "cuda:0 f32[1, 512, 1]" - t279 = prims.rsqrt(t278) # t279: "cuda:0 f32[1, 512, 1]" - t280 = prims.broadcast_in_dim(t279, (1, 512, 4096), (0, 1, 2)) # t280: "cuda:0 f32[1, 512, 4096]" - t281 = ltorch.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - # t281 = prims.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - t282 = prims.convert_element_type(t281, dtypes.bfloat16) # t282: "cuda:0 bf16[1, 512, 4096]" - t283 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_weight, (1, 512, 4096), (2,)) # t283: "cuda:0 bf16[1, 512, 4096]" - t284 = prims.convert_element_type(t282, dtypes.float32) # t284: "cuda:0 f32[1, 512, 4096]" - t285 = prims.convert_element_type(t283, dtypes.float32) # t285: "cuda:0 f32[1, 512, 4096]" - t286 = ltorch.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - # t286 = prims.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - t287 = prims.convert_element_type(t286, dtypes.bfloat16) # t287: "cuda:0 bf16[1, 512, 4096]" - t288 = prims.linear(t287, t_transformer_h_2_attn_attn_weight, None) # t288: "cuda:0 bf16[1, 512, 12288]" - t294 = prims.reshape(t288, (1, 512, 32, 3, 128)) # t294: "cuda:0 bf16[1, 512, 32, 3, 128]" - t300 = prims.transpose(t294, (0, 2, 3, 1, 4)) # t300: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t301, t302, t303) = ltorch.split(t300, (1, 1, 1), 2) - # t301 = prims.slice_prim(t300, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t301: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t302 = prims.slice_prim(t300, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t302: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t303 = prims.slice_prim(t300, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t303: "cuda:0 bf16[1, 32, 1, 512, 128]" - t309 = prims.reshape(t301, (1, 32, 512, 128)) # t309: "cuda:0 bf16[1, 32, 512, 128]" - t315 = prims.reshape(t302, (1, 32, 512, 128)) # t315: "cuda:0 bf16[1, 32, 512, 128]" - t321 = prims.reshape(t303, (1, 32, 512, 128)) # t321: "cuda:0 bf16[1, 32, 512, 128]" - t322 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t322: "cuda:0 bf16[1, 32, 512, 128]" - t323 = prims.slice_prim(t322, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t323: "cuda:0 bf16[1, 32, 512, 64]" - t324 = prims.slice_prim(t322, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t324: "cuda:0 bf16[1, 32, 512, 64]" - t325 = prims.convert_element_type(t324, dtypes.float32) # t325: "cuda:0 f32[1, 32, 512, 64]" - t326 = prims.neg(t325) # t326: "cuda:0 f32[1, 32, 512, 64]" - t327 = prims.convert_element_type(t326, dtypes.bfloat16) # t327: "cuda:0 bf16[1, 32, 512, 64]" - t329 = prims.cat((t327, t323), -1) # t329: "cuda:0 bf16[1, 32, 512, 128]" - t330 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t330: "cuda:0 f32[1, 32, 512, 128]" - t331 = prims.convert_element_type(t322, dtypes.float32) # t331: "cuda:0 f32[1, 32, 512, 128]" - t332 = ltorch.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - # t332 = prims.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - t333 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t333: "cuda:0 f32[1, 32, 512, 128]" - t334 = prims.convert_element_type(t329, dtypes.float32) # t334: "cuda:0 f32[1, 32, 512, 128]" - t335 = ltorch.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - # t335 = prims.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - t336 = ltorch.add(t332, t335, alpha=None) # t336: "cuda:0 f32[1, 32, 512, 128]" - # t336 = prims.add(t332, t335) # t336: "cuda:0 f32[1, 32, 512, 128]" - t337 = prims.convert_element_type(t336, dtypes.bfloat16) # t337: "cuda:0 bf16[1, 32, 512, 128]" - t338 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t338: "cuda:0 bf16[1, 32, 512, 128]" - t339 = prims.slice_prim(t338, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t339: "cuda:0 bf16[1, 32, 512, 64]" - t340 = prims.slice_prim(t338, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t340: "cuda:0 bf16[1, 32, 512, 64]" - t341 = prims.convert_element_type(t340, dtypes.float32) # t341: "cuda:0 f32[1, 32, 512, 64]" - t342 = prims.neg(t341) # t342: "cuda:0 f32[1, 32, 512, 64]" - t343 = prims.convert_element_type(t342, dtypes.bfloat16) # t343: "cuda:0 bf16[1, 32, 512, 64]" - t345 = prims.cat((t343, t339), -1) # t345: "cuda:0 bf16[1, 32, 512, 128]" - t346 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t346: "cuda:0 f32[1, 32, 512, 128]" - t347 = prims.convert_element_type(t338, dtypes.float32) # t347: "cuda:0 f32[1, 32, 512, 128]" - t348 = ltorch.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - # t348 = prims.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - t349 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t349: "cuda:0 f32[1, 32, 512, 128]" - t350 = prims.convert_element_type(t345, dtypes.float32) # t350: "cuda:0 f32[1, 32, 512, 128]" - t351 = ltorch.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - # t351 = prims.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - t352 = ltorch.add(t348, t351, alpha=None) # t352: "cuda:0 f32[1, 32, 512, 128]" - # t352 = prims.add(t348, t351) # t352: "cuda:0 f32[1, 32, 512, 128]" - t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: "cuda:0 bf16[1, 32, 512, 128]" - t354 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t354: "cuda:0 bf16[1, 32, 512, 0]" - t356 = prims.cat((t337, t354), -1) # t356: "cuda:0 bf16[1, 32, 512, 128]" - t357 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t357: "cuda:0 bf16[1, 32, 512, 0]" - t359 = prims.cat((t353, t357), -1) # t359: "cuda:0 bf16[1, 32, 512, 128]" - (t360, t361, t362, t363) = cudnn_sdpa_fwd(t356, t359, t321, None, 0.0, True, scale=0.08838834764831843) - t366 = prims.transpose(t360, (0, 2, 1, 3)) # t366: "cuda:0 bf16[1, 512, 32, 128]" - t370 = prims.reshape(t366, (1, 512, 4096)) # t370: "cuda:0 bf16[1, 512, 4096]" - t371 = prims.linear(t370, t_transformer_h_2_attn_proj_weight, None) # t371: "cuda:0 bf16[1, 512, 4096]" - t372 = prims.convert_element_type(t371, dtypes.float32) # t372: "cuda:0 f32[1, 512, 4096]" - t373 = prims.convert_element_type(t269, dtypes.float32) # t373: "cuda:0 f32[1, 512, 4096]" - t374 = ltorch.add(t372, t373, alpha=None) # t374: "cuda:0 f32[1, 512, 4096]" - # t374 = prims.add(t372, t373) # t374: "cuda:0 f32[1, 512, 4096]" - t375 = prims.convert_element_type(t374, dtypes.bfloat16) # t375: "cuda:0 bf16[1, 512, 4096]" - t376 = prims.convert_element_type(t375, dtypes.float32) # t376: "cuda:0 f32[1, 512, 4096]" - t377 = ltorch.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - # t377 = prims.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - t379 = prims.sum(t377, (2,)) # t379: "cuda:0 f32[1, 512]" - t380 = prims.broadcast_in_dim(t379, [1, 512, 1], [0, 1]) # t380: "cuda:0 f32[1, 512, 1]" - t382 = ltorch.true_divide(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - # t382 = prims.div(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - t384 = ltorch.add(t382, 1e-05, alpha=None) # t384: "cuda:0 f32[1, 512, 1]" - # t384 = prims.add(t382, 1e-05) # t384: "cuda:0 f32[1, 512, 1]" - t385 = prims.rsqrt(t384) # t385: "cuda:0 f32[1, 512, 1]" - t386 = prims.broadcast_in_dim(t385, (1, 512, 4096), (0, 1, 2)) # t386: "cuda:0 f32[1, 512, 4096]" - t387 = ltorch.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - # t387 = prims.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - t388 = prims.convert_element_type(t387, dtypes.bfloat16) # t388: "cuda:0 bf16[1, 512, 4096]" - t389 = prims.broadcast_in_dim(t_transformer_h_2_norm_2_weight, (1, 512, 4096), (2,)) # t389: "cuda:0 bf16[1, 512, 4096]" - t390 = prims.convert_element_type(t388, dtypes.float32) # t390: "cuda:0 f32[1, 512, 4096]" - t391 = prims.convert_element_type(t389, dtypes.float32) # t391: "cuda:0 f32[1, 512, 4096]" - t392 = ltorch.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - # t392 = prims.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - t393 = prims.convert_element_type(t392, dtypes.bfloat16) # t393: "cuda:0 bf16[1, 512, 4096]" - t394 = prims.linear(t393, t_transformer_h_2_mlp_fc_1_weight, None) # t394: "cuda:0 bf16[1, 512, 11008]" - t395 = prims.linear(t393, t_transformer_h_2_mlp_fc_2_weight, None) # t395: "cuda:0 bf16[1, 512, 11008]" - t396 = prims.convert_element_type(t394, dtypes.float32) # t396: "cuda:0 f32[1, 512, 11008]" - t397 = prims.neg(t396) # t397: "cuda:0 f32[1, 512, 11008]" - t398 = prims.exp(t397) # t398: "cuda:0 f32[1, 512, 11008]" - t399 = ltorch.add(1.0, t398, alpha=None) # t399: "cuda:0 f32[1, 512, 11008]" - # t399 = prims.add(1.0, t398) # t399: "cuda:0 f32[1, 512, 11008]" - t400 = prims.reciprocal(t399) # t400: "cuda:0 f32[1, 512, 11008]" - t401 = prims.convert_element_type(t400, dtypes.bfloat16) # t401: "cuda:0 bf16[1, 512, 11008]" - t402 = prims.convert_element_type(t394, dtypes.float32) # t402: "cuda:0 f32[1, 512, 11008]" - t403 = prims.convert_element_type(t401, dtypes.float32) # t403: "cuda:0 f32[1, 512, 11008]" - t404 = ltorch.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - # t404 = prims.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - t405 = prims.convert_element_type(t404, dtypes.bfloat16) # t405: "cuda:0 bf16[1, 512, 11008]" - t406 = prims.convert_element_type(t405, dtypes.float32) # t406: "cuda:0 f32[1, 512, 11008]" - t407 = prims.convert_element_type(t395, dtypes.float32) # t407: "cuda:0 f32[1, 512, 11008]" - t408 = ltorch.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - # t408 = prims.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - t409 = prims.convert_element_type(t408, dtypes.bfloat16) # t409: "cuda:0 bf16[1, 512, 11008]" - t410 = prims.linear(t409, t_transformer_h_2_mlp_proj_weight, None) # t410: "cuda:0 bf16[1, 512, 4096]" - t411 = prims.convert_element_type(t410, dtypes.float32) # t411: "cuda:0 f32[1, 512, 4096]" - t412 = prims.convert_element_type(t375, dtypes.float32) # t412: "cuda:0 f32[1, 512, 4096]" - t413 = ltorch.add(t411, t412, alpha=None) # t413: "cuda:0 f32[1, 512, 4096]" - # t413 = prims.add(t411, t412) # t413: "cuda:0 f32[1, 512, 4096]" - t414 = prims.convert_element_type(t413, dtypes.bfloat16) # t414: "cuda:0 bf16[1, 512, 4096]" - t415 = prims.convert_element_type(t414, dtypes.float32) # t415: "cuda:0 f32[1, 512, 4096]" - t416 = ltorch.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - # t416 = prims.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - t418 = prims.sum(t416, (2,)) # t418: "cuda:0 f32[1, 512]" - t419 = prims.broadcast_in_dim(t418, [1, 512, 1], [0, 1]) # t419: "cuda:0 f32[1, 512, 1]" - t421 = ltorch.true_divide(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - # t421 = prims.div(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - t423 = ltorch.add(t421, 1e-05, alpha=None) # t423: "cuda:0 f32[1, 512, 1]" - # t423 = prims.add(t421, 1e-05) # t423: "cuda:0 f32[1, 512, 1]" - t424 = prims.rsqrt(t423) # t424: "cuda:0 f32[1, 512, 1]" - t425 = prims.broadcast_in_dim(t424, (1, 512, 4096), (0, 1, 2)) # t425: "cuda:0 f32[1, 512, 4096]" - t426 = ltorch.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - # t426 = prims.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - t427 = prims.convert_element_type(t426, dtypes.bfloat16) # t427: "cuda:0 bf16[1, 512, 4096]" - t428 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_weight, (1, 512, 4096), (2,)) # t428: "cuda:0 bf16[1, 512, 4096]" - t429 = prims.convert_element_type(t427, dtypes.float32) # t429: "cuda:0 f32[1, 512, 4096]" - t430 = prims.convert_element_type(t428, dtypes.float32) # t430: "cuda:0 f32[1, 512, 4096]" - t431 = ltorch.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - # t431 = prims.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - t432 = prims.convert_element_type(t431, dtypes.bfloat16) # t432: "cuda:0 bf16[1, 512, 4096]" - t433 = prims.linear(t432, t_transformer_h_3_attn_attn_weight, None) # t433: "cuda:0 bf16[1, 512, 12288]" - t439 = prims.reshape(t433, (1, 512, 32, 3, 128)) # t439: "cuda:0 bf16[1, 512, 32, 3, 128]" - t445 = prims.transpose(t439, (0, 2, 3, 1, 4)) # t445: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t446, t447, t448) = ltorch.split(t445, (1, 1, 1), 2) - # t446 = prims.slice_prim(t445, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t446: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t447 = prims.slice_prim(t445, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t447: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t448 = prims.slice_prim(t445, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t448: "cuda:0 bf16[1, 32, 1, 512, 128]" - t454 = prims.reshape(t446, (1, 32, 512, 128)) # t454: "cuda:0 bf16[1, 32, 512, 128]" - t460 = prims.reshape(t447, (1, 32, 512, 128)) # t460: "cuda:0 bf16[1, 32, 512, 128]" - t466 = prims.reshape(t448, (1, 32, 512, 128)) # t466: "cuda:0 bf16[1, 32, 512, 128]" - t467 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t467: "cuda:0 bf16[1, 32, 512, 128]" - t468 = prims.slice_prim(t467, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t468: "cuda:0 bf16[1, 32, 512, 64]" - t469 = prims.slice_prim(t467, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t469: "cuda:0 bf16[1, 32, 512, 64]" - t470 = prims.convert_element_type(t469, dtypes.float32) # t470: "cuda:0 f32[1, 32, 512, 64]" - t471 = prims.neg(t470) # t471: "cuda:0 f32[1, 32, 512, 64]" - t472 = prims.convert_element_type(t471, dtypes.bfloat16) # t472: "cuda:0 bf16[1, 32, 512, 64]" - t474 = prims.cat((t472, t468), -1) # t474: "cuda:0 bf16[1, 32, 512, 128]" - t475 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t475: "cuda:0 f32[1, 32, 512, 128]" - t476 = prims.convert_element_type(t467, dtypes.float32) # t476: "cuda:0 f32[1, 32, 512, 128]" - t477 = ltorch.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - # t477 = prims.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - t478 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t478: "cuda:0 f32[1, 32, 512, 128]" - t479 = prims.convert_element_type(t474, dtypes.float32) # t479: "cuda:0 f32[1, 32, 512, 128]" - t480 = ltorch.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - # t480 = prims.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - t481 = ltorch.add(t477, t480, alpha=None) # t481: "cuda:0 f32[1, 32, 512, 128]" - # t481 = prims.add(t477, t480) # t481: "cuda:0 f32[1, 32, 512, 128]" - t482 = prims.convert_element_type(t481, dtypes.bfloat16) # t482: "cuda:0 bf16[1, 32, 512, 128]" - t483 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t483: "cuda:0 bf16[1, 32, 512, 128]" - t484 = prims.slice_prim(t483, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t484: "cuda:0 bf16[1, 32, 512, 64]" - t485 = prims.slice_prim(t483, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t485: "cuda:0 bf16[1, 32, 512, 64]" - t486 = prims.convert_element_type(t485, dtypes.float32) # t486: "cuda:0 f32[1, 32, 512, 64]" - t487 = prims.neg(t486) # t487: "cuda:0 f32[1, 32, 512, 64]" - t488 = prims.convert_element_type(t487, dtypes.bfloat16) # t488: "cuda:0 bf16[1, 32, 512, 64]" - t490 = prims.cat((t488, t484), -1) # t490: "cuda:0 bf16[1, 32, 512, 128]" - t491 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t491: "cuda:0 f32[1, 32, 512, 128]" - t492 = prims.convert_element_type(t483, dtypes.float32) # t492: "cuda:0 f32[1, 32, 512, 128]" - t493 = ltorch.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - # t493 = prims.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - t494 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t494: "cuda:0 f32[1, 32, 512, 128]" - t495 = prims.convert_element_type(t490, dtypes.float32) # t495: "cuda:0 f32[1, 32, 512, 128]" - t496 = ltorch.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - # t496 = prims.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - t497 = ltorch.add(t493, t496, alpha=None) # t497: "cuda:0 f32[1, 32, 512, 128]" - # t497 = prims.add(t493, t496) # t497: "cuda:0 f32[1, 32, 512, 128]" - t498 = prims.convert_element_type(t497, dtypes.bfloat16) # t498: "cuda:0 bf16[1, 32, 512, 128]" - t499 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t499: "cuda:0 bf16[1, 32, 512, 0]" - t501 = prims.cat((t482, t499), -1) # t501: "cuda:0 bf16[1, 32, 512, 128]" - t502 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t502: "cuda:0 bf16[1, 32, 512, 0]" - t504 = prims.cat((t498, t502), -1) # t504: "cuda:0 bf16[1, 32, 512, 128]" - (t505, t506, t507, t508) = cudnn_sdpa_fwd(t501, t504, t466, None, 0.0, True, scale=0.08838834764831843) - t511 = prims.transpose(t505, (0, 2, 1, 3)) # t511: "cuda:0 bf16[1, 512, 32, 128]" - t515 = prims.reshape(t511, (1, 512, 4096)) # t515: "cuda:0 bf16[1, 512, 4096]" - t516 = prims.linear(t515, t_transformer_h_3_attn_proj_weight, None) # t516: "cuda:0 bf16[1, 512, 4096]" - t517 = prims.convert_element_type(t516, dtypes.float32) # t517: "cuda:0 f32[1, 512, 4096]" - t518 = prims.convert_element_type(t414, dtypes.float32) # t518: "cuda:0 f32[1, 512, 4096]" - t519 = ltorch.add(t517, t518, alpha=None) # t519: "cuda:0 f32[1, 512, 4096]" - # t519 = prims.add(t517, t518) # t519: "cuda:0 f32[1, 512, 4096]" - t520 = prims.convert_element_type(t519, dtypes.bfloat16) # t520: "cuda:0 bf16[1, 512, 4096]" - t521 = prims.convert_element_type(t520, dtypes.float32) # t521: "cuda:0 f32[1, 512, 4096]" - t522 = ltorch.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - # t522 = prims.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - t524 = prims.sum(t522, (2,)) # t524: "cuda:0 f32[1, 512]" - t525 = prims.broadcast_in_dim(t524, [1, 512, 1], [0, 1]) # t525: "cuda:0 f32[1, 512, 1]" - t527 = ltorch.true_divide(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - # t527 = prims.div(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - t529 = ltorch.add(t527, 1e-05, alpha=None) # t529: "cuda:0 f32[1, 512, 1]" - # t529 = prims.add(t527, 1e-05) # t529: "cuda:0 f32[1, 512, 1]" - t530 = prims.rsqrt(t529) # t530: "cuda:0 f32[1, 512, 1]" - t531 = prims.broadcast_in_dim(t530, (1, 512, 4096), (0, 1, 2)) # t531: "cuda:0 f32[1, 512, 4096]" - t532 = ltorch.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - # t532 = prims.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: "cuda:0 bf16[1, 512, 4096]" - t534 = prims.broadcast_in_dim(t_transformer_h_3_norm_2_weight, (1, 512, 4096), (2,)) # t534: "cuda:0 bf16[1, 512, 4096]" - t535 = prims.convert_element_type(t533, dtypes.float32) # t535: "cuda:0 f32[1, 512, 4096]" - t536 = prims.convert_element_type(t534, dtypes.float32) # t536: "cuda:0 f32[1, 512, 4096]" - t537 = ltorch.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - # t537 = prims.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - t538 = prims.convert_element_type(t537, dtypes.bfloat16) # t538: "cuda:0 bf16[1, 512, 4096]" - t539 = prims.linear(t538, t_transformer_h_3_mlp_fc_1_weight, None) # t539: "cuda:0 bf16[1, 512, 11008]" - t540 = prims.linear(t538, t_transformer_h_3_mlp_fc_2_weight, None) # t540: "cuda:0 bf16[1, 512, 11008]" - t541 = prims.convert_element_type(t539, dtypes.float32) # t541: "cuda:0 f32[1, 512, 11008]" - t542 = prims.neg(t541) # t542: "cuda:0 f32[1, 512, 11008]" - t543 = prims.exp(t542) # t543: "cuda:0 f32[1, 512, 11008]" - t544 = ltorch.add(1.0, t543, alpha=None) # t544: "cuda:0 f32[1, 512, 11008]" - # t544 = prims.add(1.0, t543) # t544: "cuda:0 f32[1, 512, 11008]" - t545 = prims.reciprocal(t544) # t545: "cuda:0 f32[1, 512, 11008]" - t546 = prims.convert_element_type(t545, dtypes.bfloat16) # t546: "cuda:0 bf16[1, 512, 11008]" - t547 = prims.convert_element_type(t539, dtypes.float32) # t547: "cuda:0 f32[1, 512, 11008]" - t548 = prims.convert_element_type(t546, dtypes.float32) # t548: "cuda:0 f32[1, 512, 11008]" - t549 = ltorch.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - # t549 = prims.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - t550 = prims.convert_element_type(t549, dtypes.bfloat16) # t550: "cuda:0 bf16[1, 512, 11008]" - t551 = prims.convert_element_type(t550, dtypes.float32) # t551: "cuda:0 f32[1, 512, 11008]" - t552 = prims.convert_element_type(t540, dtypes.float32) # t552: "cuda:0 f32[1, 512, 11008]" - t553 = ltorch.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - # t553 = prims.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: "cuda:0 bf16[1, 512, 11008]" - t555 = prims.linear(t554, t_transformer_h_3_mlp_proj_weight, None) # t555: "cuda:0 bf16[1, 512, 4096]" - t556 = prims.convert_element_type(t555, dtypes.float32) # t556: "cuda:0 f32[1, 512, 4096]" - t557 = prims.convert_element_type(t520, dtypes.float32) # t557: "cuda:0 f32[1, 512, 4096]" - t558 = ltorch.add(t556, t557, alpha=None) # t558: "cuda:0 f32[1, 512, 4096]" - # t558 = prims.add(t556, t557) # t558: "cuda:0 f32[1, 512, 4096]" - t559 = prims.convert_element_type(t558, dtypes.bfloat16) # t559: "cuda:0 bf16[1, 512, 4096]" - t560 = prims.convert_element_type(t559, dtypes.float32) # t560: "cuda:0 f32[1, 512, 4096]" - t561 = ltorch.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - # t561 = prims.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - t563 = prims.sum(t561, (2,)) # t563: "cuda:0 f32[1, 512]" - t564 = prims.broadcast_in_dim(t563, [1, 512, 1], [0, 1]) # t564: "cuda:0 f32[1, 512, 1]" - t566 = ltorch.true_divide(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - # t566 = prims.div(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - t568 = ltorch.add(t566, 1e-05, alpha=None) # t568: "cuda:0 f32[1, 512, 1]" - # t568 = prims.add(t566, 1e-05) # t568: "cuda:0 f32[1, 512, 1]" - t569 = prims.rsqrt(t568) # t569: "cuda:0 f32[1, 512, 1]" - t570 = prims.broadcast_in_dim(t569, (1, 512, 4096), (0, 1, 2)) # t570: "cuda:0 f32[1, 512, 4096]" - t571 = ltorch.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - # t571 = prims.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - t572 = prims.convert_element_type(t571, dtypes.bfloat16) # t572: "cuda:0 bf16[1, 512, 4096]" - t573 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_weight, (1, 512, 4096), (2,)) # t573: "cuda:0 bf16[1, 512, 4096]" - t574 = prims.convert_element_type(t572, dtypes.float32) # t574: "cuda:0 f32[1, 512, 4096]" - t575 = prims.convert_element_type(t573, dtypes.float32) # t575: "cuda:0 f32[1, 512, 4096]" - t576 = ltorch.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - # t576 = prims.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - t577 = prims.convert_element_type(t576, dtypes.bfloat16) # t577: "cuda:0 bf16[1, 512, 4096]" - t578 = prims.linear(t577, t_transformer_h_4_attn_attn_weight, None) # t578: "cuda:0 bf16[1, 512, 12288]" - t584 = prims.reshape(t578, (1, 512, 32, 3, 128)) # t584: "cuda:0 bf16[1, 512, 32, 3, 128]" - t590 = prims.transpose(t584, (0, 2, 3, 1, 4)) # t590: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t591, t592, t593) = ltorch.split(t590, (1, 1, 1), 2) - # t591 = prims.slice_prim(t590, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t591: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t592 = prims.slice_prim(t590, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t592: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t593 = prims.slice_prim(t590, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t593: "cuda:0 bf16[1, 32, 1, 512, 128]" - t599 = prims.reshape(t591, (1, 32, 512, 128)) # t599: "cuda:0 bf16[1, 32, 512, 128]" - t605 = prims.reshape(t592, (1, 32, 512, 128)) # t605: "cuda:0 bf16[1, 32, 512, 128]" - t611 = prims.reshape(t593, (1, 32, 512, 128)) # t611: "cuda:0 bf16[1, 32, 512, 128]" - t612 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t612: "cuda:0 bf16[1, 32, 512, 128]" - t613 = prims.slice_prim(t612, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t613: "cuda:0 bf16[1, 32, 512, 64]" - t614 = prims.slice_prim(t612, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t614: "cuda:0 bf16[1, 32, 512, 64]" - t615 = prims.convert_element_type(t614, dtypes.float32) # t615: "cuda:0 f32[1, 32, 512, 64]" - t616 = prims.neg(t615) # t616: "cuda:0 f32[1, 32, 512, 64]" - t617 = prims.convert_element_type(t616, dtypes.bfloat16) # t617: "cuda:0 bf16[1, 32, 512, 64]" - t619 = prims.cat((t617, t613), -1) # t619: "cuda:0 bf16[1, 32, 512, 128]" - t620 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t620: "cuda:0 f32[1, 32, 512, 128]" - t621 = prims.convert_element_type(t612, dtypes.float32) # t621: "cuda:0 f32[1, 32, 512, 128]" - t622 = ltorch.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - # t622 = prims.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - t623 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t623: "cuda:0 f32[1, 32, 512, 128]" - t624 = prims.convert_element_type(t619, dtypes.float32) # t624: "cuda:0 f32[1, 32, 512, 128]" - t625 = ltorch.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - # t625 = prims.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - t626 = ltorch.add(t622, t625, alpha=None) # t626: "cuda:0 f32[1, 32, 512, 128]" - # t626 = prims.add(t622, t625) # t626: "cuda:0 f32[1, 32, 512, 128]" - t627 = prims.convert_element_type(t626, dtypes.bfloat16) # t627: "cuda:0 bf16[1, 32, 512, 128]" - t628 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t628: "cuda:0 bf16[1, 32, 512, 128]" - t629 = prims.slice_prim(t628, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t629: "cuda:0 bf16[1, 32, 512, 64]" - t630 = prims.slice_prim(t628, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t630: "cuda:0 bf16[1, 32, 512, 64]" - t631 = prims.convert_element_type(t630, dtypes.float32) # t631: "cuda:0 f32[1, 32, 512, 64]" - t632 = prims.neg(t631) # t632: "cuda:0 f32[1, 32, 512, 64]" - t633 = prims.convert_element_type(t632, dtypes.bfloat16) # t633: "cuda:0 bf16[1, 32, 512, 64]" - t635 = prims.cat((t633, t629), -1) # t635: "cuda:0 bf16[1, 32, 512, 128]" - t636 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t636: "cuda:0 f32[1, 32, 512, 128]" - t637 = prims.convert_element_type(t628, dtypes.float32) # t637: "cuda:0 f32[1, 32, 512, 128]" - t638 = ltorch.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - # t638 = prims.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - t639 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t639: "cuda:0 f32[1, 32, 512, 128]" - t640 = prims.convert_element_type(t635, dtypes.float32) # t640: "cuda:0 f32[1, 32, 512, 128]" - t641 = ltorch.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - # t641 = prims.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - t642 = ltorch.add(t638, t641, alpha=None) # t642: "cuda:0 f32[1, 32, 512, 128]" - # t642 = prims.add(t638, t641) # t642: "cuda:0 f32[1, 32, 512, 128]" - t643 = prims.convert_element_type(t642, dtypes.bfloat16) # t643: "cuda:0 bf16[1, 32, 512, 128]" - t644 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t644: "cuda:0 bf16[1, 32, 512, 0]" - t646 = prims.cat((t627, t644), -1) # t646: "cuda:0 bf16[1, 32, 512, 128]" - t647 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t647: "cuda:0 bf16[1, 32, 512, 0]" - t649 = prims.cat((t643, t647), -1) # t649: "cuda:0 bf16[1, 32, 512, 128]" - (t650, t651, t652, t653) = cudnn_sdpa_fwd(t646, t649, t611, None, 0.0, True, scale=0.08838834764831843) - t656 = prims.transpose(t650, (0, 2, 1, 3)) # t656: "cuda:0 bf16[1, 512, 32, 128]" - t660 = prims.reshape(t656, (1, 512, 4096)) # t660: "cuda:0 bf16[1, 512, 4096]" - t661 = prims.linear(t660, t_transformer_h_4_attn_proj_weight, None) # t661: "cuda:0 bf16[1, 512, 4096]" - t662 = prims.convert_element_type(t661, dtypes.float32) # t662: "cuda:0 f32[1, 512, 4096]" - t663 = prims.convert_element_type(t559, dtypes.float32) # t663: "cuda:0 f32[1, 512, 4096]" - t664 = ltorch.add(t662, t663, alpha=None) # t664: "cuda:0 f32[1, 512, 4096]" - # t664 = prims.add(t662, t663) # t664: "cuda:0 f32[1, 512, 4096]" - t665 = prims.convert_element_type(t664, dtypes.bfloat16) # t665: "cuda:0 bf16[1, 512, 4096]" - t666 = prims.convert_element_type(t665, dtypes.float32) # t666: "cuda:0 f32[1, 512, 4096]" - t667 = ltorch.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - # t667 = prims.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - t669 = prims.sum(t667, (2,)) # t669: "cuda:0 f32[1, 512]" - t670 = prims.broadcast_in_dim(t669, [1, 512, 1], [0, 1]) # t670: "cuda:0 f32[1, 512, 1]" - t672 = ltorch.true_divide(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - # t672 = prims.div(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - t674 = ltorch.add(t672, 1e-05, alpha=None) # t674: "cuda:0 f32[1, 512, 1]" - # t674 = prims.add(t672, 1e-05) # t674: "cuda:0 f32[1, 512, 1]" - t675 = prims.rsqrt(t674) # t675: "cuda:0 f32[1, 512, 1]" - t676 = prims.broadcast_in_dim(t675, (1, 512, 4096), (0, 1, 2)) # t676: "cuda:0 f32[1, 512, 4096]" - t677 = ltorch.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - # t677 = prims.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - t678 = prims.convert_element_type(t677, dtypes.bfloat16) # t678: "cuda:0 bf16[1, 512, 4096]" - t679 = prims.broadcast_in_dim(t_transformer_h_4_norm_2_weight, (1, 512, 4096), (2,)) # t679: "cuda:0 bf16[1, 512, 4096]" - t680 = prims.convert_element_type(t678, dtypes.float32) # t680: "cuda:0 f32[1, 512, 4096]" - t681 = prims.convert_element_type(t679, dtypes.float32) # t681: "cuda:0 f32[1, 512, 4096]" - t682 = ltorch.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - # t682 = prims.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - t683 = prims.convert_element_type(t682, dtypes.bfloat16) # t683: "cuda:0 bf16[1, 512, 4096]" - t684 = prims.linear(t683, t_transformer_h_4_mlp_fc_1_weight, None) # t684: "cuda:0 bf16[1, 512, 11008]" - t685 = prims.linear(t683, t_transformer_h_4_mlp_fc_2_weight, None) # t685: "cuda:0 bf16[1, 512, 11008]" - t686 = prims.convert_element_type(t684, dtypes.float32) # t686: "cuda:0 f32[1, 512, 11008]" - t687 = prims.neg(t686) # t687: "cuda:0 f32[1, 512, 11008]" - t688 = prims.exp(t687) # t688: "cuda:0 f32[1, 512, 11008]" - t689 = ltorch.add(1.0, t688, alpha=None) # t689: "cuda:0 f32[1, 512, 11008]" - # t689 = prims.add(1.0, t688) # t689: "cuda:0 f32[1, 512, 11008]" - t690 = prims.reciprocal(t689) # t690: "cuda:0 f32[1, 512, 11008]" - t691 = prims.convert_element_type(t690, dtypes.bfloat16) # t691: "cuda:0 bf16[1, 512, 11008]" - t692 = prims.convert_element_type(t684, dtypes.float32) # t692: "cuda:0 f32[1, 512, 11008]" - t693 = prims.convert_element_type(t691, dtypes.float32) # t693: "cuda:0 f32[1, 512, 11008]" - t694 = ltorch.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - # t694 = prims.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - t695 = prims.convert_element_type(t694, dtypes.bfloat16) # t695: "cuda:0 bf16[1, 512, 11008]" - t696 = prims.convert_element_type(t695, dtypes.float32) # t696: "cuda:0 f32[1, 512, 11008]" - t697 = prims.convert_element_type(t685, dtypes.float32) # t697: "cuda:0 f32[1, 512, 11008]" - t698 = ltorch.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - # t698 = prims.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - t699 = prims.convert_element_type(t698, dtypes.bfloat16) # t699: "cuda:0 bf16[1, 512, 11008]" - t700 = prims.linear(t699, t_transformer_h_4_mlp_proj_weight, None) # t700: "cuda:0 bf16[1, 512, 4096]" - t701 = prims.convert_element_type(t700, dtypes.float32) # t701: "cuda:0 f32[1, 512, 4096]" - t702 = prims.convert_element_type(t665, dtypes.float32) # t702: "cuda:0 f32[1, 512, 4096]" - t703 = ltorch.add(t701, t702, alpha=None) # t703: "cuda:0 f32[1, 512, 4096]" - # t703 = prims.add(t701, t702) # t703: "cuda:0 f32[1, 512, 4096]" - t704 = prims.convert_element_type(t703, dtypes.bfloat16) # t704: "cuda:0 bf16[1, 512, 4096]" - t705 = prims.convert_element_type(t704, dtypes.float32) # t705: "cuda:0 f32[1, 512, 4096]" - t706 = ltorch.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - # t706 = prims.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - t708 = prims.sum(t706, (2,)) # t708: "cuda:0 f32[1, 512]" - t709 = prims.broadcast_in_dim(t708, [1, 512, 1], [0, 1]) # t709: "cuda:0 f32[1, 512, 1]" - t711 = ltorch.true_divide(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - # t711 = prims.div(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - t713 = ltorch.add(t711, 1e-05, alpha=None) # t713: "cuda:0 f32[1, 512, 1]" - # t713 = prims.add(t711, 1e-05) # t713: "cuda:0 f32[1, 512, 1]" - t714 = prims.rsqrt(t713) # t714: "cuda:0 f32[1, 512, 1]" - t715 = prims.broadcast_in_dim(t714, (1, 512, 4096), (0, 1, 2)) # t715: "cuda:0 f32[1, 512, 4096]" - t716 = ltorch.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - # t716 = prims.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - t717 = prims.convert_element_type(t716, dtypes.bfloat16) # t717: "cuda:0 bf16[1, 512, 4096]" - t718 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_weight, (1, 512, 4096), (2,)) # t718: "cuda:0 bf16[1, 512, 4096]" - t719 = prims.convert_element_type(t717, dtypes.float32) # t719: "cuda:0 f32[1, 512, 4096]" - t720 = prims.convert_element_type(t718, dtypes.float32) # t720: "cuda:0 f32[1, 512, 4096]" - t721 = ltorch.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - # t721 = prims.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - t722 = prims.convert_element_type(t721, dtypes.bfloat16) # t722: "cuda:0 bf16[1, 512, 4096]" - t723 = prims.linear(t722, t_transformer_h_5_attn_attn_weight, None) # t723: "cuda:0 bf16[1, 512, 12288]" - t729 = prims.reshape(t723, (1, 512, 32, 3, 128)) # t729: "cuda:0 bf16[1, 512, 32, 3, 128]" - t735 = prims.transpose(t729, (0, 2, 3, 1, 4)) # t735: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t736, t737, t738) = ltorch.split(t735, (1, 1, 1), 2) - # t736 = prims.slice_prim(t735, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t736: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t737 = prims.slice_prim(t735, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t737: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t738 = prims.slice_prim(t735, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t738: "cuda:0 bf16[1, 32, 1, 512, 128]" - t744 = prims.reshape(t736, (1, 32, 512, 128)) # t744: "cuda:0 bf16[1, 32, 512, 128]" - t750 = prims.reshape(t737, (1, 32, 512, 128)) # t750: "cuda:0 bf16[1, 32, 512, 128]" - t756 = prims.reshape(t738, (1, 32, 512, 128)) # t756: "cuda:0 bf16[1, 32, 512, 128]" - t757 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t757: "cuda:0 bf16[1, 32, 512, 128]" - t758 = prims.slice_prim(t757, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t758: "cuda:0 bf16[1, 32, 512, 64]" - t759 = prims.slice_prim(t757, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t759: "cuda:0 bf16[1, 32, 512, 64]" - t760 = prims.convert_element_type(t759, dtypes.float32) # t760: "cuda:0 f32[1, 32, 512, 64]" - t761 = prims.neg(t760) # t761: "cuda:0 f32[1, 32, 512, 64]" - t762 = prims.convert_element_type(t761, dtypes.bfloat16) # t762: "cuda:0 bf16[1, 32, 512, 64]" - t764 = prims.cat((t762, t758), -1) # t764: "cuda:0 bf16[1, 32, 512, 128]" - t765 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t765: "cuda:0 f32[1, 32, 512, 128]" - t766 = prims.convert_element_type(t757, dtypes.float32) # t766: "cuda:0 f32[1, 32, 512, 128]" - t767 = ltorch.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - # t767 = prims.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - t768 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t768: "cuda:0 f32[1, 32, 512, 128]" - t769 = prims.convert_element_type(t764, dtypes.float32) # t769: "cuda:0 f32[1, 32, 512, 128]" - t770 = ltorch.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - # t770 = prims.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - t771 = ltorch.add(t767, t770, alpha=None) # t771: "cuda:0 f32[1, 32, 512, 128]" - # t771 = prims.add(t767, t770) # t771: "cuda:0 f32[1, 32, 512, 128]" - t772 = prims.convert_element_type(t771, dtypes.bfloat16) # t772: "cuda:0 bf16[1, 32, 512, 128]" - t773 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t773: "cuda:0 bf16[1, 32, 512, 128]" - t774 = prims.slice_prim(t773, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t774: "cuda:0 bf16[1, 32, 512, 64]" - t775 = prims.slice_prim(t773, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t775: "cuda:0 bf16[1, 32, 512, 64]" - t776 = prims.convert_element_type(t775, dtypes.float32) # t776: "cuda:0 f32[1, 32, 512, 64]" - t777 = prims.neg(t776) # t777: "cuda:0 f32[1, 32, 512, 64]" - t778 = prims.convert_element_type(t777, dtypes.bfloat16) # t778: "cuda:0 bf16[1, 32, 512, 64]" - t780 = prims.cat((t778, t774), -1) # t780: "cuda:0 bf16[1, 32, 512, 128]" - t781 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t781: "cuda:0 f32[1, 32, 512, 128]" - t782 = prims.convert_element_type(t773, dtypes.float32) # t782: "cuda:0 f32[1, 32, 512, 128]" - t783 = ltorch.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - # t783 = prims.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - t784 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t784: "cuda:0 f32[1, 32, 512, 128]" - t785 = prims.convert_element_type(t780, dtypes.float32) # t785: "cuda:0 f32[1, 32, 512, 128]" - t786 = ltorch.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - # t786 = prims.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - t787 = ltorch.add(t783, t786, alpha=None) # t787: "cuda:0 f32[1, 32, 512, 128]" - # t787 = prims.add(t783, t786) # t787: "cuda:0 f32[1, 32, 512, 128]" - t788 = prims.convert_element_type(t787, dtypes.bfloat16) # t788: "cuda:0 bf16[1, 32, 512, 128]" - t789 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t789: "cuda:0 bf16[1, 32, 512, 0]" - t791 = prims.cat((t772, t789), -1) # t791: "cuda:0 bf16[1, 32, 512, 128]" - t792 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t792: "cuda:0 bf16[1, 32, 512, 0]" - t794 = prims.cat((t788, t792), -1) # t794: "cuda:0 bf16[1, 32, 512, 128]" - (t795, t796, t797, t798) = cudnn_sdpa_fwd(t791, t794, t756, None, 0.0, True, scale=0.08838834764831843) - t801 = prims.transpose(t795, (0, 2, 1, 3)) # t801: "cuda:0 bf16[1, 512, 32, 128]" - t805 = prims.reshape(t801, (1, 512, 4096)) # t805: "cuda:0 bf16[1, 512, 4096]" - t806 = prims.linear(t805, t_transformer_h_5_attn_proj_weight, None) # t806: "cuda:0 bf16[1, 512, 4096]" - t807 = prims.convert_element_type(t806, dtypes.float32) # t807: "cuda:0 f32[1, 512, 4096]" - t808 = prims.convert_element_type(t704, dtypes.float32) # t808: "cuda:0 f32[1, 512, 4096]" - t809 = ltorch.add(t807, t808, alpha=None) # t809: "cuda:0 f32[1, 512, 4096]" - # t809 = prims.add(t807, t808) # t809: "cuda:0 f32[1, 512, 4096]" - t810 = prims.convert_element_type(t809, dtypes.bfloat16) # t810: "cuda:0 bf16[1, 512, 4096]" - t811 = prims.convert_element_type(t810, dtypes.float32) # t811: "cuda:0 f32[1, 512, 4096]" - t812 = ltorch.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - # t812 = prims.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - t814 = prims.sum(t812, (2,)) # t814: "cuda:0 f32[1, 512]" - t815 = prims.broadcast_in_dim(t814, [1, 512, 1], [0, 1]) # t815: "cuda:0 f32[1, 512, 1]" - t817 = ltorch.true_divide(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - # t817 = prims.div(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - t819 = ltorch.add(t817, 1e-05, alpha=None) # t819: "cuda:0 f32[1, 512, 1]" - # t819 = prims.add(t817, 1e-05) # t819: "cuda:0 f32[1, 512, 1]" - t820 = prims.rsqrt(t819) # t820: "cuda:0 f32[1, 512, 1]" - t821 = prims.broadcast_in_dim(t820, (1, 512, 4096), (0, 1, 2)) # t821: "cuda:0 f32[1, 512, 4096]" - t822 = ltorch.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - # t822 = prims.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - t823 = prims.convert_element_type(t822, dtypes.bfloat16) # t823: "cuda:0 bf16[1, 512, 4096]" - t824 = prims.broadcast_in_dim(t_transformer_h_5_norm_2_weight, (1, 512, 4096), (2,)) # t824: "cuda:0 bf16[1, 512, 4096]" - t825 = prims.convert_element_type(t823, dtypes.float32) # t825: "cuda:0 f32[1, 512, 4096]" - t826 = prims.convert_element_type(t824, dtypes.float32) # t826: "cuda:0 f32[1, 512, 4096]" - t827 = ltorch.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - # t827 = prims.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - t828 = prims.convert_element_type(t827, dtypes.bfloat16) # t828: "cuda:0 bf16[1, 512, 4096]" - t829 = prims.linear(t828, t_transformer_h_5_mlp_fc_1_weight, None) # t829: "cuda:0 bf16[1, 512, 11008]" - t830 = prims.linear(t828, t_transformer_h_5_mlp_fc_2_weight, None) # t830: "cuda:0 bf16[1, 512, 11008]" - t831 = prims.convert_element_type(t829, dtypes.float32) # t831: "cuda:0 f32[1, 512, 11008]" - t832 = prims.neg(t831) # t832: "cuda:0 f32[1, 512, 11008]" - t833 = prims.exp(t832) # t833: "cuda:0 f32[1, 512, 11008]" - t834 = ltorch.add(1.0, t833, alpha=None) # t834: "cuda:0 f32[1, 512, 11008]" - # t834 = prims.add(1.0, t833) # t834: "cuda:0 f32[1, 512, 11008]" - t835 = prims.reciprocal(t834) # t835: "cuda:0 f32[1, 512, 11008]" - t836 = prims.convert_element_type(t835, dtypes.bfloat16) # t836: "cuda:0 bf16[1, 512, 11008]" - t837 = prims.convert_element_type(t829, dtypes.float32) # t837: "cuda:0 f32[1, 512, 11008]" - t838 = prims.convert_element_type(t836, dtypes.float32) # t838: "cuda:0 f32[1, 512, 11008]" - t839 = ltorch.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - # t839 = prims.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - t840 = prims.convert_element_type(t839, dtypes.bfloat16) # t840: "cuda:0 bf16[1, 512, 11008]" - t841 = prims.convert_element_type(t840, dtypes.float32) # t841: "cuda:0 f32[1, 512, 11008]" - t842 = prims.convert_element_type(t830, dtypes.float32) # t842: "cuda:0 f32[1, 512, 11008]" - t843 = ltorch.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - # t843 = prims.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - t844 = prims.convert_element_type(t843, dtypes.bfloat16) # t844: "cuda:0 bf16[1, 512, 11008]" - t845 = prims.linear(t844, t_transformer_h_5_mlp_proj_weight, None) # t845: "cuda:0 bf16[1, 512, 4096]" - t846 = prims.convert_element_type(t845, dtypes.float32) # t846: "cuda:0 f32[1, 512, 4096]" - t847 = prims.convert_element_type(t810, dtypes.float32) # t847: "cuda:0 f32[1, 512, 4096]" - t848 = ltorch.add(t846, t847, alpha=None) # t848: "cuda:0 f32[1, 512, 4096]" - # t848 = prims.add(t846, t847) # t848: "cuda:0 f32[1, 512, 4096]" - t849 = prims.convert_element_type(t848, dtypes.bfloat16) # t849: "cuda:0 bf16[1, 512, 4096]" - t850 = prims.convert_element_type(t849, dtypes.float32) # t850: "cuda:0 f32[1, 512, 4096]" - t851 = ltorch.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - # t851 = prims.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - t853 = prims.sum(t851, (2,)) # t853: "cuda:0 f32[1, 512]" - t854 = prims.broadcast_in_dim(t853, [1, 512, 1], [0, 1]) # t854: "cuda:0 f32[1, 512, 1]" - t856 = ltorch.true_divide(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - # t856 = prims.div(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - t858 = ltorch.add(t856, 1e-05, alpha=None) # t858: "cuda:0 f32[1, 512, 1]" - # t858 = prims.add(t856, 1e-05) # t858: "cuda:0 f32[1, 512, 1]" - t859 = prims.rsqrt(t858) # t859: "cuda:0 f32[1, 512, 1]" - t860 = prims.broadcast_in_dim(t859, (1, 512, 4096), (0, 1, 2)) # t860: "cuda:0 f32[1, 512, 4096]" - t861 = ltorch.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - # t861 = prims.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - t862 = prims.convert_element_type(t861, dtypes.bfloat16) # t862: "cuda:0 bf16[1, 512, 4096]" - t863 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_weight, (1, 512, 4096), (2,)) # t863: "cuda:0 bf16[1, 512, 4096]" - t864 = prims.convert_element_type(t862, dtypes.float32) # t864: "cuda:0 f32[1, 512, 4096]" - t865 = prims.convert_element_type(t863, dtypes.float32) # t865: "cuda:0 f32[1, 512, 4096]" - t866 = ltorch.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - # t866 = prims.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - t867 = prims.convert_element_type(t866, dtypes.bfloat16) # t867: "cuda:0 bf16[1, 512, 4096]" - t868 = prims.linear(t867, t_transformer_h_6_attn_attn_weight, None) # t868: "cuda:0 bf16[1, 512, 12288]" - t874 = prims.reshape(t868, (1, 512, 32, 3, 128)) # t874: "cuda:0 bf16[1, 512, 32, 3, 128]" - t880 = prims.transpose(t874, (0, 2, 3, 1, 4)) # t880: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t881, t882, t883) = ltorch.split(t880, (1, 1, 1), 2) - # t881 = prims.slice_prim(t880, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t881: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t882 = prims.slice_prim(t880, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t882: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t883 = prims.slice_prim(t880, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t883: "cuda:0 bf16[1, 32, 1, 512, 128]" - t889 = prims.reshape(t881, (1, 32, 512, 128)) # t889: "cuda:0 bf16[1, 32, 512, 128]" - t895 = prims.reshape(t882, (1, 32, 512, 128)) # t895: "cuda:0 bf16[1, 32, 512, 128]" - t901 = prims.reshape(t883, (1, 32, 512, 128)) # t901: "cuda:0 bf16[1, 32, 512, 128]" - t902 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t902: "cuda:0 bf16[1, 32, 512, 128]" - t903 = prims.slice_prim(t902, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t903: "cuda:0 bf16[1, 32, 512, 64]" - t904 = prims.slice_prim(t902, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t904: "cuda:0 bf16[1, 32, 512, 64]" - t905 = prims.convert_element_type(t904, dtypes.float32) # t905: "cuda:0 f32[1, 32, 512, 64]" - t906 = prims.neg(t905) # t906: "cuda:0 f32[1, 32, 512, 64]" - t907 = prims.convert_element_type(t906, dtypes.bfloat16) # t907: "cuda:0 bf16[1, 32, 512, 64]" - t909 = prims.cat((t907, t903), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - t910 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t910: "cuda:0 f32[1, 32, 512, 128]" - t911 = prims.convert_element_type(t902, dtypes.float32) # t911: "cuda:0 f32[1, 32, 512, 128]" - t912 = ltorch.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - # t912 = prims.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - t913 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t913: "cuda:0 f32[1, 32, 512, 128]" - t914 = prims.convert_element_type(t909, dtypes.float32) # t914: "cuda:0 f32[1, 32, 512, 128]" - t915 = ltorch.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - # t915 = prims.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - t916 = ltorch.add(t912, t915, alpha=None) # t916: "cuda:0 f32[1, 32, 512, 128]" - # t916 = prims.add(t912, t915) # t916: "cuda:0 f32[1, 32, 512, 128]" - t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: "cuda:0 bf16[1, 32, 512, 128]" - t918 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: "cuda:0 bf16[1, 32, 512, 128]" - t919 = prims.slice_prim(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: "cuda:0 bf16[1, 32, 512, 64]" - t920 = prims.slice_prim(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: "cuda:0 bf16[1, 32, 512, 64]" - t921 = prims.convert_element_type(t920, dtypes.float32) # t921: "cuda:0 f32[1, 32, 512, 64]" - t922 = prims.neg(t921) # t922: "cuda:0 f32[1, 32, 512, 64]" - t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: "cuda:0 bf16[1, 32, 512, 64]" - t925 = prims.cat((t923, t919), -1) # t925: "cuda:0 bf16[1, 32, 512, 128]" - t926 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t926: "cuda:0 f32[1, 32, 512, 128]" - t927 = prims.convert_element_type(t918, dtypes.float32) # t927: "cuda:0 f32[1, 32, 512, 128]" - t928 = ltorch.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - # t928 = prims.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - t929 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t929: "cuda:0 f32[1, 32, 512, 128]" - t930 = prims.convert_element_type(t925, dtypes.float32) # t930: "cuda:0 f32[1, 32, 512, 128]" - t931 = ltorch.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - # t931 = prims.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - t932 = ltorch.add(t928, t931, alpha=None) # t932: "cuda:0 f32[1, 32, 512, 128]" - # t932 = prims.add(t928, t931) # t932: "cuda:0 f32[1, 32, 512, 128]" - t933 = prims.convert_element_type(t932, dtypes.bfloat16) # t933: "cuda:0 bf16[1, 32, 512, 128]" - t934 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t934: "cuda:0 bf16[1, 32, 512, 0]" - t936 = prims.cat((t917, t934), -1) # t936: "cuda:0 bf16[1, 32, 512, 128]" - t937 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t937: "cuda:0 bf16[1, 32, 512, 0]" - t939 = prims.cat((t933, t937), -1) # t939: "cuda:0 bf16[1, 32, 512, 128]" - (t940, t941, t942, t943) = cudnn_sdpa_fwd(t936, t939, t901, None, 0.0, True, scale=0.08838834764831843) - t946 = prims.transpose(t940, (0, 2, 1, 3)) # t946: "cuda:0 bf16[1, 512, 32, 128]" - t950 = prims.reshape(t946, (1, 512, 4096)) # t950: "cuda:0 bf16[1, 512, 4096]" - t951 = prims.linear(t950, t_transformer_h_6_attn_proj_weight, None) # t951: "cuda:0 bf16[1, 512, 4096]" - t952 = prims.convert_element_type(t951, dtypes.float32) # t952: "cuda:0 f32[1, 512, 4096]" - t953 = prims.convert_element_type(t849, dtypes.float32) # t953: "cuda:0 f32[1, 512, 4096]" - t954 = ltorch.add(t952, t953, alpha=None) # t954: "cuda:0 f32[1, 512, 4096]" - # t954 = prims.add(t952, t953) # t954: "cuda:0 f32[1, 512, 4096]" - t955 = prims.convert_element_type(t954, dtypes.bfloat16) # t955: "cuda:0 bf16[1, 512, 4096]" - t956 = prims.convert_element_type(t955, dtypes.float32) # t956: "cuda:0 f32[1, 512, 4096]" - t957 = ltorch.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - # t957 = prims.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - t959 = prims.sum(t957, (2,)) # t959: "cuda:0 f32[1, 512]" - t960 = prims.broadcast_in_dim(t959, [1, 512, 1], [0, 1]) # t960: "cuda:0 f32[1, 512, 1]" - t962 = ltorch.true_divide(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - # t962 = prims.div(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - t964 = ltorch.add(t962, 1e-05, alpha=None) # t964: "cuda:0 f32[1, 512, 1]" - # t964 = prims.add(t962, 1e-05) # t964: "cuda:0 f32[1, 512, 1]" - t965 = prims.rsqrt(t964) # t965: "cuda:0 f32[1, 512, 1]" - t966 = prims.broadcast_in_dim(t965, (1, 512, 4096), (0, 1, 2)) # t966: "cuda:0 f32[1, 512, 4096]" - t967 = ltorch.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - # t967 = prims.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - t968 = prims.convert_element_type(t967, dtypes.bfloat16) # t968: "cuda:0 bf16[1, 512, 4096]" - t969 = prims.broadcast_in_dim(t_transformer_h_6_norm_2_weight, (1, 512, 4096), (2,)) # t969: "cuda:0 bf16[1, 512, 4096]" - t970 = prims.convert_element_type(t968, dtypes.float32) # t970: "cuda:0 f32[1, 512, 4096]" - t971 = prims.convert_element_type(t969, dtypes.float32) # t971: "cuda:0 f32[1, 512, 4096]" - t972 = ltorch.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - # t972 = prims.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - t973 = prims.convert_element_type(t972, dtypes.bfloat16) # t973: "cuda:0 bf16[1, 512, 4096]" - t974 = prims.linear(t973, t_transformer_h_6_mlp_fc_1_weight, None) # t974: "cuda:0 bf16[1, 512, 11008]" - t975 = prims.linear(t973, t_transformer_h_6_mlp_fc_2_weight, None) # t975: "cuda:0 bf16[1, 512, 11008]" - t976 = prims.convert_element_type(t974, dtypes.float32) # t976: "cuda:0 f32[1, 512, 11008]" - t977 = prims.neg(t976) # t977: "cuda:0 f32[1, 512, 11008]" - t978 = prims.exp(t977) # t978: "cuda:0 f32[1, 512, 11008]" - t979 = ltorch.add(1.0, t978, alpha=None) # t979: "cuda:0 f32[1, 512, 11008]" - # t979 = prims.add(1.0, t978) # t979: "cuda:0 f32[1, 512, 11008]" - t980 = prims.reciprocal(t979) # t980: "cuda:0 f32[1, 512, 11008]" - t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: "cuda:0 bf16[1, 512, 11008]" - t982 = prims.convert_element_type(t974, dtypes.float32) # t982: "cuda:0 f32[1, 512, 11008]" - t983 = prims.convert_element_type(t981, dtypes.float32) # t983: "cuda:0 f32[1, 512, 11008]" - t984 = ltorch.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - # t984 = prims.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - t985 = prims.convert_element_type(t984, dtypes.bfloat16) # t985: "cuda:0 bf16[1, 512, 11008]" - t986 = prims.convert_element_type(t985, dtypes.float32) # t986: "cuda:0 f32[1, 512, 11008]" - t987 = prims.convert_element_type(t975, dtypes.float32) # t987: "cuda:0 f32[1, 512, 11008]" - t988 = ltorch.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - # t988 = prims.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - t989 = prims.convert_element_type(t988, dtypes.bfloat16) # t989: "cuda:0 bf16[1, 512, 11008]" - t990 = prims.linear(t989, t_transformer_h_6_mlp_proj_weight, None) # t990: "cuda:0 bf16[1, 512, 4096]" - t991 = prims.convert_element_type(t990, dtypes.float32) # t991: "cuda:0 f32[1, 512, 4096]" - t992 = prims.convert_element_type(t955, dtypes.float32) # t992: "cuda:0 f32[1, 512, 4096]" - t993 = ltorch.add(t991, t992, alpha=None) # t993: "cuda:0 f32[1, 512, 4096]" - # t993 = prims.add(t991, t992) # t993: "cuda:0 f32[1, 512, 4096]" - t994 = prims.convert_element_type(t993, dtypes.bfloat16) # t994: "cuda:0 bf16[1, 512, 4096]" - t995 = prims.convert_element_type(t994, dtypes.float32) # t995: "cuda:0 f32[1, 512, 4096]" - t996 = ltorch.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - # t996 = prims.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - t998 = prims.sum(t996, (2,)) # t998: "cuda:0 f32[1, 512]" - t999 = prims.broadcast_in_dim(t998, [1, 512, 1], [0, 1]) # t999: "cuda:0 f32[1, 512, 1]" - t1001 = ltorch.true_divide(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - # t1001 = prims.div(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - t1003 = ltorch.add(t1001, 1e-05, alpha=None) # t1003: "cuda:0 f32[1, 512, 1]" - # t1003 = prims.add(t1001, 1e-05) # t1003: "cuda:0 f32[1, 512, 1]" - t1004 = prims.rsqrt(t1003) # t1004: "cuda:0 f32[1, 512, 1]" - t1005 = prims.broadcast_in_dim(t1004, (1, 512, 4096), (0, 1, 2)) # t1005: "cuda:0 f32[1, 512, 4096]" - t1006 = ltorch.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - # t1006 = prims.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - t1007 = prims.convert_element_type(t1006, dtypes.bfloat16) # t1007: "cuda:0 bf16[1, 512, 4096]" - t1008 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_weight, (1, 512, 4096), (2,)) # t1008: "cuda:0 bf16[1, 512, 4096]" - t1009 = prims.convert_element_type(t1007, dtypes.float32) # t1009: "cuda:0 f32[1, 512, 4096]" - t1010 = prims.convert_element_type(t1008, dtypes.float32) # t1010: "cuda:0 f32[1, 512, 4096]" - t1011 = ltorch.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - # t1011 = prims.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - t1012 = prims.convert_element_type(t1011, dtypes.bfloat16) # t1012: "cuda:0 bf16[1, 512, 4096]" - t1013 = prims.linear(t1012, t_transformer_h_7_attn_attn_weight, None) # t1013: "cuda:0 bf16[1, 512, 12288]" - t1019 = prims.reshape(t1013, (1, 512, 32, 3, 128)) # t1019: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1025 = prims.transpose(t1019, (0, 2, 3, 1, 4)) # t1025: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1026, t1027, t1028) = ltorch.split(t1025, (1, 1, 1), 2) - # t1026 = prims.slice_prim(t1025, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1026: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1027 = prims.slice_prim(t1025, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1027: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1028 = prims.slice_prim(t1025, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1028: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1034 = prims.reshape(t1026, (1, 32, 512, 128)) # t1034: "cuda:0 bf16[1, 32, 512, 128]" - t1040 = prims.reshape(t1027, (1, 32, 512, 128)) # t1040: "cuda:0 bf16[1, 32, 512, 128]" - t1046 = prims.reshape(t1028, (1, 32, 512, 128)) # t1046: "cuda:0 bf16[1, 32, 512, 128]" - t1047 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1047: "cuda:0 bf16[1, 32, 512, 128]" - t1048 = prims.slice_prim(t1047, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1048: "cuda:0 bf16[1, 32, 512, 64]" - t1049 = prims.slice_prim(t1047, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1049: "cuda:0 bf16[1, 32, 512, 64]" - t1050 = prims.convert_element_type(t1049, dtypes.float32) # t1050: "cuda:0 f32[1, 32, 512, 64]" - t1051 = prims.neg(t1050) # t1051: "cuda:0 f32[1, 32, 512, 64]" - t1052 = prims.convert_element_type(t1051, dtypes.bfloat16) # t1052: "cuda:0 bf16[1, 32, 512, 64]" - t1054 = prims.cat((t1052, t1048), -1) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - t1055 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1055: "cuda:0 f32[1, 32, 512, 128]" - t1056 = prims.convert_element_type(t1047, dtypes.float32) # t1056: "cuda:0 f32[1, 32, 512, 128]" - t1057 = ltorch.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - # t1057 = prims.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - t1058 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1058: "cuda:0 f32[1, 32, 512, 128]" - t1059 = prims.convert_element_type(t1054, dtypes.float32) # t1059: "cuda:0 f32[1, 32, 512, 128]" - t1060 = ltorch.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - # t1060 = prims.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - t1061 = ltorch.add(t1057, t1060, alpha=None) # t1061: "cuda:0 f32[1, 32, 512, 128]" - # t1061 = prims.add(t1057, t1060) # t1061: "cuda:0 f32[1, 32, 512, 128]" - t1062 = prims.convert_element_type(t1061, dtypes.bfloat16) # t1062: "cuda:0 bf16[1, 32, 512, 128]" - t1063 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1063: "cuda:0 bf16[1, 32, 512, 128]" - t1064 = prims.slice_prim(t1063, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1064: "cuda:0 bf16[1, 32, 512, 64]" - t1065 = prims.slice_prim(t1063, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1065: "cuda:0 bf16[1, 32, 512, 64]" - t1066 = prims.convert_element_type(t1065, dtypes.float32) # t1066: "cuda:0 f32[1, 32, 512, 64]" - t1067 = prims.neg(t1066) # t1067: "cuda:0 f32[1, 32, 512, 64]" - t1068 = prims.convert_element_type(t1067, dtypes.bfloat16) # t1068: "cuda:0 bf16[1, 32, 512, 64]" - t1070 = prims.cat((t1068, t1064), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - t1071 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1071: "cuda:0 f32[1, 32, 512, 128]" - t1072 = prims.convert_element_type(t1063, dtypes.float32) # t1072: "cuda:0 f32[1, 32, 512, 128]" - t1073 = ltorch.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - # t1073 = prims.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - t1074 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1074: "cuda:0 f32[1, 32, 512, 128]" - t1075 = prims.convert_element_type(t1070, dtypes.float32) # t1075: "cuda:0 f32[1, 32, 512, 128]" - t1076 = ltorch.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - # t1076 = prims.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - t1077 = ltorch.add(t1073, t1076, alpha=None) # t1077: "cuda:0 f32[1, 32, 512, 128]" - # t1077 = prims.add(t1073, t1076) # t1077: "cuda:0 f32[1, 32, 512, 128]" - t1078 = prims.convert_element_type(t1077, dtypes.bfloat16) # t1078: "cuda:0 bf16[1, 32, 512, 128]" - t1079 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1079: "cuda:0 bf16[1, 32, 512, 0]" - t1081 = prims.cat((t1062, t1079), -1) # t1081: "cuda:0 bf16[1, 32, 512, 128]" - t1082 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1082: "cuda:0 bf16[1, 32, 512, 0]" - t1084 = prims.cat((t1078, t1082), -1) # t1084: "cuda:0 bf16[1, 32, 512, 128]" - (t1085, t1086, t1087, t1088) = cudnn_sdpa_fwd(t1081, t1084, t1046, None, 0.0, True, scale=0.08838834764831843) - t1091 = prims.transpose(t1085, (0, 2, 1, 3)) # t1091: "cuda:0 bf16[1, 512, 32, 128]" - t1095 = prims.reshape(t1091, (1, 512, 4096)) # t1095: "cuda:0 bf16[1, 512, 4096]" - t1096 = prims.linear(t1095, t_transformer_h_7_attn_proj_weight, None) # t1096: "cuda:0 bf16[1, 512, 4096]" - t1097 = prims.convert_element_type(t1096, dtypes.float32) # t1097: "cuda:0 f32[1, 512, 4096]" - t1098 = prims.convert_element_type(t994, dtypes.float32) # t1098: "cuda:0 f32[1, 512, 4096]" - t1099 = ltorch.add(t1097, t1098, alpha=None) # t1099: "cuda:0 f32[1, 512, 4096]" - # t1099 = prims.add(t1097, t1098) # t1099: "cuda:0 f32[1, 512, 4096]" - t1100 = prims.convert_element_type(t1099, dtypes.bfloat16) # t1100: "cuda:0 bf16[1, 512, 4096]" - t1101 = prims.convert_element_type(t1100, dtypes.float32) # t1101: "cuda:0 f32[1, 512, 4096]" - t1102 = ltorch.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - # t1102 = prims.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - t1104 = prims.sum(t1102, (2,)) # t1104: "cuda:0 f32[1, 512]" - t1105 = prims.broadcast_in_dim(t1104, [1, 512, 1], [0, 1]) # t1105: "cuda:0 f32[1, 512, 1]" - t1107 = ltorch.true_divide(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - # t1107 = prims.div(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - t1109 = ltorch.add(t1107, 1e-05, alpha=None) # t1109: "cuda:0 f32[1, 512, 1]" - # t1109 = prims.add(t1107, 1e-05) # t1109: "cuda:0 f32[1, 512, 1]" - t1110 = prims.rsqrt(t1109) # t1110: "cuda:0 f32[1, 512, 1]" - t1111 = prims.broadcast_in_dim(t1110, (1, 512, 4096), (0, 1, 2)) # t1111: "cuda:0 f32[1, 512, 4096]" - t1112 = ltorch.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - # t1112 = prims.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - t1113 = prims.convert_element_type(t1112, dtypes.bfloat16) # t1113: "cuda:0 bf16[1, 512, 4096]" - t1114 = prims.broadcast_in_dim(t_transformer_h_7_norm_2_weight, (1, 512, 4096), (2,)) # t1114: "cuda:0 bf16[1, 512, 4096]" - t1115 = prims.convert_element_type(t1113, dtypes.float32) # t1115: "cuda:0 f32[1, 512, 4096]" - t1116 = prims.convert_element_type(t1114, dtypes.float32) # t1116: "cuda:0 f32[1, 512, 4096]" - t1117 = ltorch.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - # t1117 = prims.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - t1118 = prims.convert_element_type(t1117, dtypes.bfloat16) # t1118: "cuda:0 bf16[1, 512, 4096]" - t1119 = prims.linear(t1118, t_transformer_h_7_mlp_fc_1_weight, None) # t1119: "cuda:0 bf16[1, 512, 11008]" - t1120 = prims.linear(t1118, t_transformer_h_7_mlp_fc_2_weight, None) # t1120: "cuda:0 bf16[1, 512, 11008]" - t1121 = prims.convert_element_type(t1119, dtypes.float32) # t1121: "cuda:0 f32[1, 512, 11008]" - t1122 = prims.neg(t1121) # t1122: "cuda:0 f32[1, 512, 11008]" - t1123 = prims.exp(t1122) # t1123: "cuda:0 f32[1, 512, 11008]" - t1124 = ltorch.add(1.0, t1123, alpha=None) # t1124: "cuda:0 f32[1, 512, 11008]" - # t1124 = prims.add(1.0, t1123) # t1124: "cuda:0 f32[1, 512, 11008]" - t1125 = prims.reciprocal(t1124) # t1125: "cuda:0 f32[1, 512, 11008]" - t1126 = prims.convert_element_type(t1125, dtypes.bfloat16) # t1126: "cuda:0 bf16[1, 512, 11008]" - t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: "cuda:0 f32[1, 512, 11008]" - t1128 = prims.convert_element_type(t1126, dtypes.float32) # t1128: "cuda:0 f32[1, 512, 11008]" - t1129 = ltorch.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - # t1129 = prims.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - t1130 = prims.convert_element_type(t1129, dtypes.bfloat16) # t1130: "cuda:0 bf16[1, 512, 11008]" - t1131 = prims.convert_element_type(t1130, dtypes.float32) # t1131: "cuda:0 f32[1, 512, 11008]" - t1132 = prims.convert_element_type(t1120, dtypes.float32) # t1132: "cuda:0 f32[1, 512, 11008]" - t1133 = ltorch.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - # t1133 = prims.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - t1134 = prims.convert_element_type(t1133, dtypes.bfloat16) # t1134: "cuda:0 bf16[1, 512, 11008]" - t1135 = prims.linear(t1134, t_transformer_h_7_mlp_proj_weight, None) # t1135: "cuda:0 bf16[1, 512, 4096]" - t1136 = prims.convert_element_type(t1135, dtypes.float32) # t1136: "cuda:0 f32[1, 512, 4096]" - t1137 = prims.convert_element_type(t1100, dtypes.float32) # t1137: "cuda:0 f32[1, 512, 4096]" - t1138 = ltorch.add(t1136, t1137, alpha=None) # t1138: "cuda:0 f32[1, 512, 4096]" - # t1138 = prims.add(t1136, t1137) # t1138: "cuda:0 f32[1, 512, 4096]" - t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: "cuda:0 bf16[1, 512, 4096]" - t1140 = prims.convert_element_type(t1139, dtypes.float32) # t1140: "cuda:0 f32[1, 512, 4096]" - t1141 = ltorch.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - # t1141 = prims.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - t1143 = prims.sum(t1141, (2,)) # t1143: "cuda:0 f32[1, 512]" - t1144 = prims.broadcast_in_dim(t1143, [1, 512, 1], [0, 1]) # t1144: "cuda:0 f32[1, 512, 1]" - t1146 = ltorch.true_divide(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - # t1146 = prims.div(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - t1148 = ltorch.add(t1146, 1e-05, alpha=None) # t1148: "cuda:0 f32[1, 512, 1]" - # t1148 = prims.add(t1146, 1e-05) # t1148: "cuda:0 f32[1, 512, 1]" - t1149 = prims.rsqrt(t1148) # t1149: "cuda:0 f32[1, 512, 1]" - t1150 = prims.broadcast_in_dim(t1149, (1, 512, 4096), (0, 1, 2)) # t1150: "cuda:0 f32[1, 512, 4096]" - t1151 = ltorch.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - # t1151 = prims.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - t1152 = prims.convert_element_type(t1151, dtypes.bfloat16) # t1152: "cuda:0 bf16[1, 512, 4096]" - t1153 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_weight, (1, 512, 4096), (2,)) # t1153: "cuda:0 bf16[1, 512, 4096]" - t1154 = prims.convert_element_type(t1152, dtypes.float32) # t1154: "cuda:0 f32[1, 512, 4096]" - t1155 = prims.convert_element_type(t1153, dtypes.float32) # t1155: "cuda:0 f32[1, 512, 4096]" - t1156 = ltorch.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - # t1156 = prims.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - t1157 = prims.convert_element_type(t1156, dtypes.bfloat16) # t1157: "cuda:0 bf16[1, 512, 4096]" - t1158 = prims.linear(t1157, t_transformer_h_8_attn_attn_weight, None) # t1158: "cuda:0 bf16[1, 512, 12288]" - t1164 = prims.reshape(t1158, (1, 512, 32, 3, 128)) # t1164: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1170 = prims.transpose(t1164, (0, 2, 3, 1, 4)) # t1170: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1171, t1172, t1173) = ltorch.split(t1170, (1, 1, 1), 2) - # t1171 = prims.slice_prim(t1170, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1171: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1172 = prims.slice_prim(t1170, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1172: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1173 = prims.slice_prim(t1170, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1173: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1179 = prims.reshape(t1171, (1, 32, 512, 128)) # t1179: "cuda:0 bf16[1, 32, 512, 128]" - t1185 = prims.reshape(t1172, (1, 32, 512, 128)) # t1185: "cuda:0 bf16[1, 32, 512, 128]" - t1191 = prims.reshape(t1173, (1, 32, 512, 128)) # t1191: "cuda:0 bf16[1, 32, 512, 128]" - t1192 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1192: "cuda:0 bf16[1, 32, 512, 128]" - t1193 = prims.slice_prim(t1192, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1193: "cuda:0 bf16[1, 32, 512, 64]" - t1194 = prims.slice_prim(t1192, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1194: "cuda:0 bf16[1, 32, 512, 64]" - t1195 = prims.convert_element_type(t1194, dtypes.float32) # t1195: "cuda:0 f32[1, 32, 512, 64]" - t1196 = prims.neg(t1195) # t1196: "cuda:0 f32[1, 32, 512, 64]" - t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: "cuda:0 bf16[1, 32, 512, 64]" - t1199 = prims.cat((t1197, t1193), -1) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - t1200 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1200: "cuda:0 f32[1, 32, 512, 128]" - t1201 = prims.convert_element_type(t1192, dtypes.float32) # t1201: "cuda:0 f32[1, 32, 512, 128]" - t1202 = ltorch.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - # t1202 = prims.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - t1203 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1203: "cuda:0 f32[1, 32, 512, 128]" - t1204 = prims.convert_element_type(t1199, dtypes.float32) # t1204: "cuda:0 f32[1, 32, 512, 128]" - t1205 = ltorch.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - # t1205 = prims.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - t1206 = ltorch.add(t1202, t1205, alpha=None) # t1206: "cuda:0 f32[1, 32, 512, 128]" - # t1206 = prims.add(t1202, t1205) # t1206: "cuda:0 f32[1, 32, 512, 128]" - t1207 = prims.convert_element_type(t1206, dtypes.bfloat16) # t1207: "cuda:0 bf16[1, 32, 512, 128]" - t1208 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - t1209 = prims.slice_prim(t1208, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1209: "cuda:0 bf16[1, 32, 512, 64]" - t1210 = prims.slice_prim(t1208, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1210: "cuda:0 bf16[1, 32, 512, 64]" - t1211 = prims.convert_element_type(t1210, dtypes.float32) # t1211: "cuda:0 f32[1, 32, 512, 64]" - t1212 = prims.neg(t1211) # t1212: "cuda:0 f32[1, 32, 512, 64]" - t1213 = prims.convert_element_type(t1212, dtypes.bfloat16) # t1213: "cuda:0 bf16[1, 32, 512, 64]" - t1215 = prims.cat((t1213, t1209), -1) # t1215: "cuda:0 bf16[1, 32, 512, 128]" - t1216 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1216: "cuda:0 f32[1, 32, 512, 128]" - t1217 = prims.convert_element_type(t1208, dtypes.float32) # t1217: "cuda:0 f32[1, 32, 512, 128]" - t1218 = ltorch.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - # t1218 = prims.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - t1219 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1219: "cuda:0 f32[1, 32, 512, 128]" - t1220 = prims.convert_element_type(t1215, dtypes.float32) # t1220: "cuda:0 f32[1, 32, 512, 128]" - t1221 = ltorch.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - # t1221 = prims.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - t1222 = ltorch.add(t1218, t1221, alpha=None) # t1222: "cuda:0 f32[1, 32, 512, 128]" - # t1222 = prims.add(t1218, t1221) # t1222: "cuda:0 f32[1, 32, 512, 128]" - t1223 = prims.convert_element_type(t1222, dtypes.bfloat16) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - t1224 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1224: "cuda:0 bf16[1, 32, 512, 0]" - t1226 = prims.cat((t1207, t1224), -1) # t1226: "cuda:0 bf16[1, 32, 512, 128]" - t1227 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1227: "cuda:0 bf16[1, 32, 512, 0]" - t1229 = prims.cat((t1223, t1227), -1) # t1229: "cuda:0 bf16[1, 32, 512, 128]" - (t1230, t1231, t1232, t1233) = cudnn_sdpa_fwd(t1226, t1229, t1191, None, 0.0, True, scale=0.08838834764831843) - t1236 = prims.transpose(t1230, (0, 2, 1, 3)) # t1236: "cuda:0 bf16[1, 512, 32, 128]" - t1240 = prims.reshape(t1236, (1, 512, 4096)) # t1240: "cuda:0 bf16[1, 512, 4096]" - t1241 = prims.linear(t1240, t_transformer_h_8_attn_proj_weight, None) # t1241: "cuda:0 bf16[1, 512, 4096]" - t1242 = prims.convert_element_type(t1241, dtypes.float32) # t1242: "cuda:0 f32[1, 512, 4096]" - t1243 = prims.convert_element_type(t1139, dtypes.float32) # t1243: "cuda:0 f32[1, 512, 4096]" - t1244 = ltorch.add(t1242, t1243, alpha=None) # t1244: "cuda:0 f32[1, 512, 4096]" - # t1244 = prims.add(t1242, t1243) # t1244: "cuda:0 f32[1, 512, 4096]" - t1245 = prims.convert_element_type(t1244, dtypes.bfloat16) # t1245: "cuda:0 bf16[1, 512, 4096]" - t1246 = prims.convert_element_type(t1245, dtypes.float32) # t1246: "cuda:0 f32[1, 512, 4096]" - t1247 = ltorch.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - # t1247 = prims.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - t1249 = prims.sum(t1247, (2,)) # t1249: "cuda:0 f32[1, 512]" - t1250 = prims.broadcast_in_dim(t1249, [1, 512, 1], [0, 1]) # t1250: "cuda:0 f32[1, 512, 1]" - t1252 = ltorch.true_divide(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - # t1252 = prims.div(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - t1254 = ltorch.add(t1252, 1e-05, alpha=None) # t1254: "cuda:0 f32[1, 512, 1]" - # t1254 = prims.add(t1252, 1e-05) # t1254: "cuda:0 f32[1, 512, 1]" - t1255 = prims.rsqrt(t1254) # t1255: "cuda:0 f32[1, 512, 1]" - t1256 = prims.broadcast_in_dim(t1255, (1, 512, 4096), (0, 1, 2)) # t1256: "cuda:0 f32[1, 512, 4096]" - t1257 = ltorch.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - # t1257 = prims.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - t1258 = prims.convert_element_type(t1257, dtypes.bfloat16) # t1258: "cuda:0 bf16[1, 512, 4096]" - t1259 = prims.broadcast_in_dim(t_transformer_h_8_norm_2_weight, (1, 512, 4096), (2,)) # t1259: "cuda:0 bf16[1, 512, 4096]" - t1260 = prims.convert_element_type(t1258, dtypes.float32) # t1260: "cuda:0 f32[1, 512, 4096]" - t1261 = prims.convert_element_type(t1259, dtypes.float32) # t1261: "cuda:0 f32[1, 512, 4096]" - t1262 = ltorch.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - # t1262 = prims.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - t1263 = prims.convert_element_type(t1262, dtypes.bfloat16) # t1263: "cuda:0 bf16[1, 512, 4096]" - t1264 = prims.linear(t1263, t_transformer_h_8_mlp_fc_1_weight, None) # t1264: "cuda:0 bf16[1, 512, 11008]" - t1265 = prims.linear(t1263, t_transformer_h_8_mlp_fc_2_weight, None) # t1265: "cuda:0 bf16[1, 512, 11008]" - t1266 = prims.convert_element_type(t1264, dtypes.float32) # t1266: "cuda:0 f32[1, 512, 11008]" - t1267 = prims.neg(t1266) # t1267: "cuda:0 f32[1, 512, 11008]" - t1268 = prims.exp(t1267) # t1268: "cuda:0 f32[1, 512, 11008]" - t1269 = ltorch.add(1.0, t1268, alpha=None) # t1269: "cuda:0 f32[1, 512, 11008]" - # t1269 = prims.add(1.0, t1268) # t1269: "cuda:0 f32[1, 512, 11008]" - t1270 = prims.reciprocal(t1269) # t1270: "cuda:0 f32[1, 512, 11008]" - t1271 = prims.convert_element_type(t1270, dtypes.bfloat16) # t1271: "cuda:0 bf16[1, 512, 11008]" - t1272 = prims.convert_element_type(t1264, dtypes.float32) # t1272: "cuda:0 f32[1, 512, 11008]" - t1273 = prims.convert_element_type(t1271, dtypes.float32) # t1273: "cuda:0 f32[1, 512, 11008]" - t1274 = ltorch.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - # t1274 = prims.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - t1275 = prims.convert_element_type(t1274, dtypes.bfloat16) # t1275: "cuda:0 bf16[1, 512, 11008]" - t1276 = prims.convert_element_type(t1275, dtypes.float32) # t1276: "cuda:0 f32[1, 512, 11008]" - t1277 = prims.convert_element_type(t1265, dtypes.float32) # t1277: "cuda:0 f32[1, 512, 11008]" - t1278 = ltorch.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - # t1278 = prims.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - t1279 = prims.convert_element_type(t1278, dtypes.bfloat16) # t1279: "cuda:0 bf16[1, 512, 11008]" - t1280 = prims.linear(t1279, t_transformer_h_8_mlp_proj_weight, None) # t1280: "cuda:0 bf16[1, 512, 4096]" - t1281 = prims.convert_element_type(t1280, dtypes.float32) # t1281: "cuda:0 f32[1, 512, 4096]" - t1282 = prims.convert_element_type(t1245, dtypes.float32) # t1282: "cuda:0 f32[1, 512, 4096]" - t1283 = ltorch.add(t1281, t1282, alpha=None) # t1283: "cuda:0 f32[1, 512, 4096]" - # t1283 = prims.add(t1281, t1282) # t1283: "cuda:0 f32[1, 512, 4096]" - t1284 = prims.convert_element_type(t1283, dtypes.bfloat16) # t1284: "cuda:0 bf16[1, 512, 4096]" - t1285 = prims.convert_element_type(t1284, dtypes.float32) # t1285: "cuda:0 f32[1, 512, 4096]" - t1286 = ltorch.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - # t1286 = prims.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - t1288 = prims.sum(t1286, (2,)) # t1288: "cuda:0 f32[1, 512]" - t1289 = prims.broadcast_in_dim(t1288, [1, 512, 1], [0, 1]) # t1289: "cuda:0 f32[1, 512, 1]" - t1291 = ltorch.true_divide(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - # t1291 = prims.div(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - t1293 = ltorch.add(t1291, 1e-05, alpha=None) # t1293: "cuda:0 f32[1, 512, 1]" - # t1293 = prims.add(t1291, 1e-05) # t1293: "cuda:0 f32[1, 512, 1]" - t1294 = prims.rsqrt(t1293) # t1294: "cuda:0 f32[1, 512, 1]" - t1295 = prims.broadcast_in_dim(t1294, (1, 512, 4096), (0, 1, 2)) # t1295: "cuda:0 f32[1, 512, 4096]" - t1296 = ltorch.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - # t1296 = prims.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - t1297 = prims.convert_element_type(t1296, dtypes.bfloat16) # t1297: "cuda:0 bf16[1, 512, 4096]" - t1298 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_weight, (1, 512, 4096), (2,)) # t1298: "cuda:0 bf16[1, 512, 4096]" - t1299 = prims.convert_element_type(t1297, dtypes.float32) # t1299: "cuda:0 f32[1, 512, 4096]" - t1300 = prims.convert_element_type(t1298, dtypes.float32) # t1300: "cuda:0 f32[1, 512, 4096]" - t1301 = ltorch.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - # t1301 = prims.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - t1302 = prims.convert_element_type(t1301, dtypes.bfloat16) # t1302: "cuda:0 bf16[1, 512, 4096]" - t1303 = prims.linear(t1302, t_transformer_h_9_attn_attn_weight, None) # t1303: "cuda:0 bf16[1, 512, 12288]" - t1309 = prims.reshape(t1303, (1, 512, 32, 3, 128)) # t1309: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1315 = prims.transpose(t1309, (0, 2, 3, 1, 4)) # t1315: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1316, t1317, t1318) = ltorch.split(t1315, (1, 1, 1), 2) - # t1316 = prims.slice_prim(t1315, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1316: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1317 = prims.slice_prim(t1315, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1317: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1318 = prims.slice_prim(t1315, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1318: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1324 = prims.reshape(t1316, (1, 32, 512, 128)) # t1324: "cuda:0 bf16[1, 32, 512, 128]" - t1330 = prims.reshape(t1317, (1, 32, 512, 128)) # t1330: "cuda:0 bf16[1, 32, 512, 128]" - t1336 = prims.reshape(t1318, (1, 32, 512, 128)) # t1336: "cuda:0 bf16[1, 32, 512, 128]" - t1337 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: "cuda:0 bf16[1, 32, 512, 128]" - t1338 = prims.slice_prim(t1337, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1338: "cuda:0 bf16[1, 32, 512, 64]" - t1339 = prims.slice_prim(t1337, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1339: "cuda:0 bf16[1, 32, 512, 64]" - t1340 = prims.convert_element_type(t1339, dtypes.float32) # t1340: "cuda:0 f32[1, 32, 512, 64]" - t1341 = prims.neg(t1340) # t1341: "cuda:0 f32[1, 32, 512, 64]" - t1342 = prims.convert_element_type(t1341, dtypes.bfloat16) # t1342: "cuda:0 bf16[1, 32, 512, 64]" - t1344 = prims.cat((t1342, t1338), -1) # t1344: "cuda:0 bf16[1, 32, 512, 128]" - t1345 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1345: "cuda:0 f32[1, 32, 512, 128]" - t1346 = prims.convert_element_type(t1337, dtypes.float32) # t1346: "cuda:0 f32[1, 32, 512, 128]" - t1347 = ltorch.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - # t1347 = prims.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - t1348 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1348: "cuda:0 f32[1, 32, 512, 128]" - t1349 = prims.convert_element_type(t1344, dtypes.float32) # t1349: "cuda:0 f32[1, 32, 512, 128]" - t1350 = ltorch.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - # t1350 = prims.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - t1351 = ltorch.add(t1347, t1350, alpha=None) # t1351: "cuda:0 f32[1, 32, 512, 128]" - # t1351 = prims.add(t1347, t1350) # t1351: "cuda:0 f32[1, 32, 512, 128]" - t1352 = prims.convert_element_type(t1351, dtypes.bfloat16) # t1352: "cuda:0 bf16[1, 32, 512, 128]" - t1353 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1353: "cuda:0 bf16[1, 32, 512, 128]" - t1354 = prims.slice_prim(t1353, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1354: "cuda:0 bf16[1, 32, 512, 64]" - t1355 = prims.slice_prim(t1353, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1355: "cuda:0 bf16[1, 32, 512, 64]" - t1356 = prims.convert_element_type(t1355, dtypes.float32) # t1356: "cuda:0 f32[1, 32, 512, 64]" - t1357 = prims.neg(t1356) # t1357: "cuda:0 f32[1, 32, 512, 64]" - t1358 = prims.convert_element_type(t1357, dtypes.bfloat16) # t1358: "cuda:0 bf16[1, 32, 512, 64]" - t1360 = prims.cat((t1358, t1354), -1) # t1360: "cuda:0 bf16[1, 32, 512, 128]" - t1361 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1361: "cuda:0 f32[1, 32, 512, 128]" - t1362 = prims.convert_element_type(t1353, dtypes.float32) # t1362: "cuda:0 f32[1, 32, 512, 128]" - t1363 = ltorch.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - # t1363 = prims.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - t1364 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1364: "cuda:0 f32[1, 32, 512, 128]" - t1365 = prims.convert_element_type(t1360, dtypes.float32) # t1365: "cuda:0 f32[1, 32, 512, 128]" - t1366 = ltorch.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - # t1366 = prims.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - t1367 = ltorch.add(t1363, t1366, alpha=None) # t1367: "cuda:0 f32[1, 32, 512, 128]" - # t1367 = prims.add(t1363, t1366) # t1367: "cuda:0 f32[1, 32, 512, 128]" - t1368 = prims.convert_element_type(t1367, dtypes.bfloat16) # t1368: "cuda:0 bf16[1, 32, 512, 128]" - t1369 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1369: "cuda:0 bf16[1, 32, 512, 0]" - t1371 = prims.cat((t1352, t1369), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - t1372 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1372: "cuda:0 bf16[1, 32, 512, 0]" - t1374 = prims.cat((t1368, t1372), -1) # t1374: "cuda:0 bf16[1, 32, 512, 128]" - (t1375, t1376, t1377, t1378) = cudnn_sdpa_fwd(t1371, t1374, t1336, None, 0.0, True, scale=0.08838834764831843) - t1381 = prims.transpose(t1375, (0, 2, 1, 3)) # t1381: "cuda:0 bf16[1, 512, 32, 128]" - t1385 = prims.reshape(t1381, (1, 512, 4096)) # t1385: "cuda:0 bf16[1, 512, 4096]" - t1386 = prims.linear(t1385, t_transformer_h_9_attn_proj_weight, None) # t1386: "cuda:0 bf16[1, 512, 4096]" - t1387 = prims.convert_element_type(t1386, dtypes.float32) # t1387: "cuda:0 f32[1, 512, 4096]" - t1388 = prims.convert_element_type(t1284, dtypes.float32) # t1388: "cuda:0 f32[1, 512, 4096]" - t1389 = ltorch.add(t1387, t1388, alpha=None) # t1389: "cuda:0 f32[1, 512, 4096]" - # t1389 = prims.add(t1387, t1388) # t1389: "cuda:0 f32[1, 512, 4096]" - t1390 = prims.convert_element_type(t1389, dtypes.bfloat16) # t1390: "cuda:0 bf16[1, 512, 4096]" - t1391 = prims.convert_element_type(t1390, dtypes.float32) # t1391: "cuda:0 f32[1, 512, 4096]" - t1392 = ltorch.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - # t1392 = prims.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - t1394 = prims.sum(t1392, (2,)) # t1394: "cuda:0 f32[1, 512]" - t1395 = prims.broadcast_in_dim(t1394, [1, 512, 1], [0, 1]) # t1395: "cuda:0 f32[1, 512, 1]" - t1397 = ltorch.true_divide(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - # t1397 = prims.div(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - t1399 = ltorch.add(t1397, 1e-05, alpha=None) # t1399: "cuda:0 f32[1, 512, 1]" - # t1399 = prims.add(t1397, 1e-05) # t1399: "cuda:0 f32[1, 512, 1]" - t1400 = prims.rsqrt(t1399) # t1400: "cuda:0 f32[1, 512, 1]" - t1401 = prims.broadcast_in_dim(t1400, (1, 512, 4096), (0, 1, 2)) # t1401: "cuda:0 f32[1, 512, 4096]" - t1402 = ltorch.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - # t1402 = prims.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - t1403 = prims.convert_element_type(t1402, dtypes.bfloat16) # t1403: "cuda:0 bf16[1, 512, 4096]" - t1404 = prims.broadcast_in_dim(t_transformer_h_9_norm_2_weight, (1, 512, 4096), (2,)) # t1404: "cuda:0 bf16[1, 512, 4096]" - t1405 = prims.convert_element_type(t1403, dtypes.float32) # t1405: "cuda:0 f32[1, 512, 4096]" - t1406 = prims.convert_element_type(t1404, dtypes.float32) # t1406: "cuda:0 f32[1, 512, 4096]" - t1407 = ltorch.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - # t1407 = prims.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - t1408 = prims.convert_element_type(t1407, dtypes.bfloat16) # t1408: "cuda:0 bf16[1, 512, 4096]" - t1409 = prims.linear(t1408, t_transformer_h_9_mlp_fc_1_weight, None) # t1409: "cuda:0 bf16[1, 512, 11008]" - t1410 = prims.linear(t1408, t_transformer_h_9_mlp_fc_2_weight, None) # t1410: "cuda:0 bf16[1, 512, 11008]" - t1411 = prims.convert_element_type(t1409, dtypes.float32) # t1411: "cuda:0 f32[1, 512, 11008]" - t1412 = prims.neg(t1411) # t1412: "cuda:0 f32[1, 512, 11008]" - t1413 = prims.exp(t1412) # t1413: "cuda:0 f32[1, 512, 11008]" - t1414 = ltorch.add(1.0, t1413, alpha=None) # t1414: "cuda:0 f32[1, 512, 11008]" - # t1414 = prims.add(1.0, t1413) # t1414: "cuda:0 f32[1, 512, 11008]" - t1415 = prims.reciprocal(t1414) # t1415: "cuda:0 f32[1, 512, 11008]" - t1416 = prims.convert_element_type(t1415, dtypes.bfloat16) # t1416: "cuda:0 bf16[1, 512, 11008]" - t1417 = prims.convert_element_type(t1409, dtypes.float32) # t1417: "cuda:0 f32[1, 512, 11008]" - t1418 = prims.convert_element_type(t1416, dtypes.float32) # t1418: "cuda:0 f32[1, 512, 11008]" - t1419 = ltorch.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - # t1419 = prims.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - t1420 = prims.convert_element_type(t1419, dtypes.bfloat16) # t1420: "cuda:0 bf16[1, 512, 11008]" - t1421 = prims.convert_element_type(t1420, dtypes.float32) # t1421: "cuda:0 f32[1, 512, 11008]" - t1422 = prims.convert_element_type(t1410, dtypes.float32) # t1422: "cuda:0 f32[1, 512, 11008]" - t1423 = ltorch.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - # t1423 = prims.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - t1424 = prims.convert_element_type(t1423, dtypes.bfloat16) # t1424: "cuda:0 bf16[1, 512, 11008]" - t1425 = prims.linear(t1424, t_transformer_h_9_mlp_proj_weight, None) # t1425: "cuda:0 bf16[1, 512, 4096]" - t1426 = prims.convert_element_type(t1425, dtypes.float32) # t1426: "cuda:0 f32[1, 512, 4096]" - t1427 = prims.convert_element_type(t1390, dtypes.float32) # t1427: "cuda:0 f32[1, 512, 4096]" - t1428 = ltorch.add(t1426, t1427, alpha=None) # t1428: "cuda:0 f32[1, 512, 4096]" - # t1428 = prims.add(t1426, t1427) # t1428: "cuda:0 f32[1, 512, 4096]" - t1429 = prims.convert_element_type(t1428, dtypes.bfloat16) # t1429: "cuda:0 bf16[1, 512, 4096]" - t1430 = prims.convert_element_type(t1429, dtypes.float32) # t1430: "cuda:0 f32[1, 512, 4096]" - t1431 = ltorch.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - # t1431 = prims.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - t1433 = prims.sum(t1431, (2,)) # t1433: "cuda:0 f32[1, 512]" - t1434 = prims.broadcast_in_dim(t1433, [1, 512, 1], [0, 1]) # t1434: "cuda:0 f32[1, 512, 1]" - t1436 = ltorch.true_divide(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - # t1436 = prims.div(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - t1438 = ltorch.add(t1436, 1e-05, alpha=None) # t1438: "cuda:0 f32[1, 512, 1]" - # t1438 = prims.add(t1436, 1e-05) # t1438: "cuda:0 f32[1, 512, 1]" - t1439 = prims.rsqrt(t1438) # t1439: "cuda:0 f32[1, 512, 1]" - t1440 = prims.broadcast_in_dim(t1439, (1, 512, 4096), (0, 1, 2)) # t1440: "cuda:0 f32[1, 512, 4096]" - t1441 = ltorch.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - # t1441 = prims.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - t1442 = prims.convert_element_type(t1441, dtypes.bfloat16) # t1442: "cuda:0 bf16[1, 512, 4096]" - t1443 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_weight, (1, 512, 4096), (2,)) # t1443: "cuda:0 bf16[1, 512, 4096]" - t1444 = prims.convert_element_type(t1442, dtypes.float32) # t1444: "cuda:0 f32[1, 512, 4096]" - t1445 = prims.convert_element_type(t1443, dtypes.float32) # t1445: "cuda:0 f32[1, 512, 4096]" - t1446 = ltorch.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - # t1446 = prims.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - t1447 = prims.convert_element_type(t1446, dtypes.bfloat16) # t1447: "cuda:0 bf16[1, 512, 4096]" - t1448 = prims.linear(t1447, t_transformer_h_10_attn_attn_weight, None) # t1448: "cuda:0 bf16[1, 512, 12288]" - t1454 = prims.reshape(t1448, (1, 512, 32, 3, 128)) # t1454: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1460 = prims.transpose(t1454, (0, 2, 3, 1, 4)) # t1460: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1461, t1462, t1463) = ltorch.split(t1460, (1, 1, 1), 2) - # t1461 = prims.slice_prim(t1460, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1461: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1462 = prims.slice_prim(t1460, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1462: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1463 = prims.slice_prim(t1460, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1463: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1469 = prims.reshape(t1461, (1, 32, 512, 128)) # t1469: "cuda:0 bf16[1, 32, 512, 128]" - t1475 = prims.reshape(t1462, (1, 32, 512, 128)) # t1475: "cuda:0 bf16[1, 32, 512, 128]" - t1481 = prims.reshape(t1463, (1, 32, 512, 128)) # t1481: "cuda:0 bf16[1, 32, 512, 128]" - t1482 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1482: "cuda:0 bf16[1, 32, 512, 128]" - t1483 = prims.slice_prim(t1482, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1483: "cuda:0 bf16[1, 32, 512, 64]" - t1484 = prims.slice_prim(t1482, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1484: "cuda:0 bf16[1, 32, 512, 64]" - t1485 = prims.convert_element_type(t1484, dtypes.float32) # t1485: "cuda:0 f32[1, 32, 512, 64]" - t1486 = prims.neg(t1485) # t1486: "cuda:0 f32[1, 32, 512, 64]" - t1487 = prims.convert_element_type(t1486, dtypes.bfloat16) # t1487: "cuda:0 bf16[1, 32, 512, 64]" - t1489 = prims.cat((t1487, t1483), -1) # t1489: "cuda:0 bf16[1, 32, 512, 128]" - t1490 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1490: "cuda:0 f32[1, 32, 512, 128]" - t1491 = prims.convert_element_type(t1482, dtypes.float32) # t1491: "cuda:0 f32[1, 32, 512, 128]" - t1492 = ltorch.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - # t1492 = prims.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - t1493 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1493: "cuda:0 f32[1, 32, 512, 128]" - t1494 = prims.convert_element_type(t1489, dtypes.float32) # t1494: "cuda:0 f32[1, 32, 512, 128]" - t1495 = ltorch.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - # t1495 = prims.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - t1496 = ltorch.add(t1492, t1495, alpha=None) # t1496: "cuda:0 f32[1, 32, 512, 128]" - # t1496 = prims.add(t1492, t1495) # t1496: "cuda:0 f32[1, 32, 512, 128]" - t1497 = prims.convert_element_type(t1496, dtypes.bfloat16) # t1497: "cuda:0 bf16[1, 32, 512, 128]" - t1498 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1498: "cuda:0 bf16[1, 32, 512, 128]" - t1499 = prims.slice_prim(t1498, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1499: "cuda:0 bf16[1, 32, 512, 64]" - t1500 = prims.slice_prim(t1498, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1500: "cuda:0 bf16[1, 32, 512, 64]" - t1501 = prims.convert_element_type(t1500, dtypes.float32) # t1501: "cuda:0 f32[1, 32, 512, 64]" - t1502 = prims.neg(t1501) # t1502: "cuda:0 f32[1, 32, 512, 64]" - t1503 = prims.convert_element_type(t1502, dtypes.bfloat16) # t1503: "cuda:0 bf16[1, 32, 512, 64]" - t1505 = prims.cat((t1503, t1499), -1) # t1505: "cuda:0 bf16[1, 32, 512, 128]" - t1506 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1506: "cuda:0 f32[1, 32, 512, 128]" - t1507 = prims.convert_element_type(t1498, dtypes.float32) # t1507: "cuda:0 f32[1, 32, 512, 128]" - t1508 = ltorch.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - # t1508 = prims.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - t1509 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1509: "cuda:0 f32[1, 32, 512, 128]" - t1510 = prims.convert_element_type(t1505, dtypes.float32) # t1510: "cuda:0 f32[1, 32, 512, 128]" - t1511 = ltorch.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - # t1511 = prims.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - t1512 = ltorch.add(t1508, t1511, alpha=None) # t1512: "cuda:0 f32[1, 32, 512, 128]" - # t1512 = prims.add(t1508, t1511) # t1512: "cuda:0 f32[1, 32, 512, 128]" - t1513 = prims.convert_element_type(t1512, dtypes.bfloat16) # t1513: "cuda:0 bf16[1, 32, 512, 128]" - t1514 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1514: "cuda:0 bf16[1, 32, 512, 0]" - t1516 = prims.cat((t1497, t1514), -1) # t1516: "cuda:0 bf16[1, 32, 512, 128]" - t1517 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1517: "cuda:0 bf16[1, 32, 512, 0]" - t1519 = prims.cat((t1513, t1517), -1) # t1519: "cuda:0 bf16[1, 32, 512, 128]" - (t1520, t1521, t1522, t1523) = cudnn_sdpa_fwd(t1516, t1519, t1481, None, 0.0, True, scale=0.08838834764831843) - t1526 = prims.transpose(t1520, (0, 2, 1, 3)) # t1526: "cuda:0 bf16[1, 512, 32, 128]" - t1530 = prims.reshape(t1526, (1, 512, 4096)) # t1530: "cuda:0 bf16[1, 512, 4096]" - t1531 = prims.linear(t1530, t_transformer_h_10_attn_proj_weight, None) # t1531: "cuda:0 bf16[1, 512, 4096]" - t1532 = prims.convert_element_type(t1531, dtypes.float32) # t1532: "cuda:0 f32[1, 512, 4096]" - t1533 = prims.convert_element_type(t1429, dtypes.float32) # t1533: "cuda:0 f32[1, 512, 4096]" - t1534 = ltorch.add(t1532, t1533, alpha=None) # t1534: "cuda:0 f32[1, 512, 4096]" - # t1534 = prims.add(t1532, t1533) # t1534: "cuda:0 f32[1, 512, 4096]" - t1535 = prims.convert_element_type(t1534, dtypes.bfloat16) # t1535: "cuda:0 bf16[1, 512, 4096]" - t1536 = prims.convert_element_type(t1535, dtypes.float32) # t1536: "cuda:0 f32[1, 512, 4096]" - t1537 = ltorch.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - # t1537 = prims.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - t1539 = prims.sum(t1537, (2,)) # t1539: "cuda:0 f32[1, 512]" - t1540 = prims.broadcast_in_dim(t1539, [1, 512, 1], [0, 1]) # t1540: "cuda:0 f32[1, 512, 1]" - t1542 = ltorch.true_divide(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - # t1542 = prims.div(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - t1544 = ltorch.add(t1542, 1e-05, alpha=None) # t1544: "cuda:0 f32[1, 512, 1]" - # t1544 = prims.add(t1542, 1e-05) # t1544: "cuda:0 f32[1, 512, 1]" - t1545 = prims.rsqrt(t1544) # t1545: "cuda:0 f32[1, 512, 1]" - t1546 = prims.broadcast_in_dim(t1545, (1, 512, 4096), (0, 1, 2)) # t1546: "cuda:0 f32[1, 512, 4096]" - t1547 = ltorch.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - # t1547 = prims.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - t1548 = prims.convert_element_type(t1547, dtypes.bfloat16) # t1548: "cuda:0 bf16[1, 512, 4096]" - t1549 = prims.broadcast_in_dim(t_transformer_h_10_norm_2_weight, (1, 512, 4096), (2,)) # t1549: "cuda:0 bf16[1, 512, 4096]" - t1550 = prims.convert_element_type(t1548, dtypes.float32) # t1550: "cuda:0 f32[1, 512, 4096]" - t1551 = prims.convert_element_type(t1549, dtypes.float32) # t1551: "cuda:0 f32[1, 512, 4096]" - t1552 = ltorch.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - # t1552 = prims.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - t1553 = prims.convert_element_type(t1552, dtypes.bfloat16) # t1553: "cuda:0 bf16[1, 512, 4096]" - t1554 = prims.linear(t1553, t_transformer_h_10_mlp_fc_1_weight, None) # t1554: "cuda:0 bf16[1, 512, 11008]" - t1555 = prims.linear(t1553, t_transformer_h_10_mlp_fc_2_weight, None) # t1555: "cuda:0 bf16[1, 512, 11008]" - t1556 = prims.convert_element_type(t1554, dtypes.float32) # t1556: "cuda:0 f32[1, 512, 11008]" - t1557 = prims.neg(t1556) # t1557: "cuda:0 f32[1, 512, 11008]" - t1558 = prims.exp(t1557) # t1558: "cuda:0 f32[1, 512, 11008]" - t1559 = ltorch.add(1.0, t1558, alpha=None) # t1559: "cuda:0 f32[1, 512, 11008]" - # t1559 = prims.add(1.0, t1558) # t1559: "cuda:0 f32[1, 512, 11008]" - t1560 = prims.reciprocal(t1559) # t1560: "cuda:0 f32[1, 512, 11008]" - t1561 = prims.convert_element_type(t1560, dtypes.bfloat16) # t1561: "cuda:0 bf16[1, 512, 11008]" - t1562 = prims.convert_element_type(t1554, dtypes.float32) # t1562: "cuda:0 f32[1, 512, 11008]" - t1563 = prims.convert_element_type(t1561, dtypes.float32) # t1563: "cuda:0 f32[1, 512, 11008]" - t1564 = ltorch.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - # t1564 = prims.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: "cuda:0 bf16[1, 512, 11008]" - t1566 = prims.convert_element_type(t1565, dtypes.float32) # t1566: "cuda:0 f32[1, 512, 11008]" - t1567 = prims.convert_element_type(t1555, dtypes.float32) # t1567: "cuda:0 f32[1, 512, 11008]" - t1568 = ltorch.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - # t1568 = prims.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - t1569 = prims.convert_element_type(t1568, dtypes.bfloat16) # t1569: "cuda:0 bf16[1, 512, 11008]" - t1570 = prims.linear(t1569, t_transformer_h_10_mlp_proj_weight, None) # t1570: "cuda:0 bf16[1, 512, 4096]" - t1571 = prims.convert_element_type(t1570, dtypes.float32) # t1571: "cuda:0 f32[1, 512, 4096]" - t1572 = prims.convert_element_type(t1535, dtypes.float32) # t1572: "cuda:0 f32[1, 512, 4096]" - t1573 = ltorch.add(t1571, t1572, alpha=None) # t1573: "cuda:0 f32[1, 512, 4096]" - # t1573 = prims.add(t1571, t1572) # t1573: "cuda:0 f32[1, 512, 4096]" - t1574 = prims.convert_element_type(t1573, dtypes.bfloat16) # t1574: "cuda:0 bf16[1, 512, 4096]" - t1575 = prims.convert_element_type(t1574, dtypes.float32) # t1575: "cuda:0 f32[1, 512, 4096]" - t1576 = ltorch.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - # t1576 = prims.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - t1578 = prims.sum(t1576, (2,)) # t1578: "cuda:0 f32[1, 512]" - t1579 = prims.broadcast_in_dim(t1578, [1, 512, 1], [0, 1]) # t1579: "cuda:0 f32[1, 512, 1]" - t1581 = ltorch.true_divide(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - # t1581 = prims.div(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - t1583 = ltorch.add(t1581, 1e-05, alpha=None) # t1583: "cuda:0 f32[1, 512, 1]" - # t1583 = prims.add(t1581, 1e-05) # t1583: "cuda:0 f32[1, 512, 1]" - t1584 = prims.rsqrt(t1583) # t1584: "cuda:0 f32[1, 512, 1]" - t1585 = prims.broadcast_in_dim(t1584, (1, 512, 4096), (0, 1, 2)) # t1585: "cuda:0 f32[1, 512, 4096]" - t1586 = ltorch.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - # t1586 = prims.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - t1587 = prims.convert_element_type(t1586, dtypes.bfloat16) # t1587: "cuda:0 bf16[1, 512, 4096]" - t1588 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_weight, (1, 512, 4096), (2,)) # t1588: "cuda:0 bf16[1, 512, 4096]" - t1589 = prims.convert_element_type(t1587, dtypes.float32) # t1589: "cuda:0 f32[1, 512, 4096]" - t1590 = prims.convert_element_type(t1588, dtypes.float32) # t1590: "cuda:0 f32[1, 512, 4096]" - t1591 = ltorch.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - # t1591 = prims.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - t1592 = prims.convert_element_type(t1591, dtypes.bfloat16) # t1592: "cuda:0 bf16[1, 512, 4096]" - t1593 = prims.linear(t1592, t_transformer_h_11_attn_attn_weight, None) # t1593: "cuda:0 bf16[1, 512, 12288]" - t1599 = prims.reshape(t1593, (1, 512, 32, 3, 128)) # t1599: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1605 = prims.transpose(t1599, (0, 2, 3, 1, 4)) # t1605: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1606, t1607, t1608) = ltorch.split(t1605, (1, 1, 1), 2) - # t1606 = prims.slice_prim(t1605, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1606: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1607 = prims.slice_prim(t1605, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1607: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1608 = prims.slice_prim(t1605, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1608: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1614 = prims.reshape(t1606, (1, 32, 512, 128)) # t1614: "cuda:0 bf16[1, 32, 512, 128]" - t1620 = prims.reshape(t1607, (1, 32, 512, 128)) # t1620: "cuda:0 bf16[1, 32, 512, 128]" - t1626 = prims.reshape(t1608, (1, 32, 512, 128)) # t1626: "cuda:0 bf16[1, 32, 512, 128]" - t1627 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1627: "cuda:0 bf16[1, 32, 512, 128]" - t1628 = prims.slice_prim(t1627, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1628: "cuda:0 bf16[1, 32, 512, 64]" - t1629 = prims.slice_prim(t1627, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1629: "cuda:0 bf16[1, 32, 512, 64]" - t1630 = prims.convert_element_type(t1629, dtypes.float32) # t1630: "cuda:0 f32[1, 32, 512, 64]" - t1631 = prims.neg(t1630) # t1631: "cuda:0 f32[1, 32, 512, 64]" - t1632 = prims.convert_element_type(t1631, dtypes.bfloat16) # t1632: "cuda:0 bf16[1, 32, 512, 64]" - t1634 = prims.cat((t1632, t1628), -1) # t1634: "cuda:0 bf16[1, 32, 512, 128]" - t1635 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1635: "cuda:0 f32[1, 32, 512, 128]" - t1636 = prims.convert_element_type(t1627, dtypes.float32) # t1636: "cuda:0 f32[1, 32, 512, 128]" - t1637 = ltorch.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - # t1637 = prims.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - t1638 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1638: "cuda:0 f32[1, 32, 512, 128]" - t1639 = prims.convert_element_type(t1634, dtypes.float32) # t1639: "cuda:0 f32[1, 32, 512, 128]" - t1640 = ltorch.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - # t1640 = prims.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - t1641 = ltorch.add(t1637, t1640, alpha=None) # t1641: "cuda:0 f32[1, 32, 512, 128]" - # t1641 = prims.add(t1637, t1640) # t1641: "cuda:0 f32[1, 32, 512, 128]" - t1642 = prims.convert_element_type(t1641, dtypes.bfloat16) # t1642: "cuda:0 bf16[1, 32, 512, 128]" - t1643 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1643: "cuda:0 bf16[1, 32, 512, 128]" - t1644 = prims.slice_prim(t1643, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1644: "cuda:0 bf16[1, 32, 512, 64]" - t1645 = prims.slice_prim(t1643, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1645: "cuda:0 bf16[1, 32, 512, 64]" - t1646 = prims.convert_element_type(t1645, dtypes.float32) # t1646: "cuda:0 f32[1, 32, 512, 64]" - t1647 = prims.neg(t1646) # t1647: "cuda:0 f32[1, 32, 512, 64]" - t1648 = prims.convert_element_type(t1647, dtypes.bfloat16) # t1648: "cuda:0 bf16[1, 32, 512, 64]" - t1650 = prims.cat((t1648, t1644), -1) # t1650: "cuda:0 bf16[1, 32, 512, 128]" - t1651 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1651: "cuda:0 f32[1, 32, 512, 128]" - t1652 = prims.convert_element_type(t1643, dtypes.float32) # t1652: "cuda:0 f32[1, 32, 512, 128]" - t1653 = ltorch.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - # t1653 = prims.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - t1654 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1654: "cuda:0 f32[1, 32, 512, 128]" - t1655 = prims.convert_element_type(t1650, dtypes.float32) # t1655: "cuda:0 f32[1, 32, 512, 128]" - t1656 = ltorch.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - # t1656 = prims.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - t1657 = ltorch.add(t1653, t1656, alpha=None) # t1657: "cuda:0 f32[1, 32, 512, 128]" - # t1657 = prims.add(t1653, t1656) # t1657: "cuda:0 f32[1, 32, 512, 128]" - t1658 = prims.convert_element_type(t1657, dtypes.bfloat16) # t1658: "cuda:0 bf16[1, 32, 512, 128]" - t1659 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1659: "cuda:0 bf16[1, 32, 512, 0]" - t1661 = prims.cat((t1642, t1659), -1) # t1661: "cuda:0 bf16[1, 32, 512, 128]" - t1662 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1662: "cuda:0 bf16[1, 32, 512, 0]" - t1664 = prims.cat((t1658, t1662), -1) # t1664: "cuda:0 bf16[1, 32, 512, 128]" - (t1665, t1666, t1667, t1668) = cudnn_sdpa_fwd(t1661, t1664, t1626, None, 0.0, True, scale=0.08838834764831843) - t1671 = prims.transpose(t1665, (0, 2, 1, 3)) # t1671: "cuda:0 bf16[1, 512, 32, 128]" - t1675 = prims.reshape(t1671, (1, 512, 4096)) # t1675: "cuda:0 bf16[1, 512, 4096]" - t1676 = prims.linear(t1675, t_transformer_h_11_attn_proj_weight, None) # t1676: "cuda:0 bf16[1, 512, 4096]" - t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: "cuda:0 f32[1, 512, 4096]" - t1678 = prims.convert_element_type(t1574, dtypes.float32) # t1678: "cuda:0 f32[1, 512, 4096]" - t1679 = ltorch.add(t1677, t1678, alpha=None) # t1679: "cuda:0 f32[1, 512, 4096]" - # t1679 = prims.add(t1677, t1678) # t1679: "cuda:0 f32[1, 512, 4096]" - t1680 = prims.convert_element_type(t1679, dtypes.bfloat16) # t1680: "cuda:0 bf16[1, 512, 4096]" - t1681 = prims.convert_element_type(t1680, dtypes.float32) # t1681: "cuda:0 f32[1, 512, 4096]" - t1682 = ltorch.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - # t1682 = prims.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - t1684 = prims.sum(t1682, (2,)) # t1684: "cuda:0 f32[1, 512]" - t1685 = prims.broadcast_in_dim(t1684, [1, 512, 1], [0, 1]) # t1685: "cuda:0 f32[1, 512, 1]" - t1687 = ltorch.true_divide(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - # t1687 = prims.div(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - t1689 = ltorch.add(t1687, 1e-05, alpha=None) # t1689: "cuda:0 f32[1, 512, 1]" - # t1689 = prims.add(t1687, 1e-05) # t1689: "cuda:0 f32[1, 512, 1]" - t1690 = prims.rsqrt(t1689) # t1690: "cuda:0 f32[1, 512, 1]" - t1691 = prims.broadcast_in_dim(t1690, (1, 512, 4096), (0, 1, 2)) # t1691: "cuda:0 f32[1, 512, 4096]" - t1692 = ltorch.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - # t1692 = prims.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - t1693 = prims.convert_element_type(t1692, dtypes.bfloat16) # t1693: "cuda:0 bf16[1, 512, 4096]" - t1694 = prims.broadcast_in_dim(t_transformer_h_11_norm_2_weight, (1, 512, 4096), (2,)) # t1694: "cuda:0 bf16[1, 512, 4096]" - t1695 = prims.convert_element_type(t1693, dtypes.float32) # t1695: "cuda:0 f32[1, 512, 4096]" - t1696 = prims.convert_element_type(t1694, dtypes.float32) # t1696: "cuda:0 f32[1, 512, 4096]" - t1697 = ltorch.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - # t1697 = prims.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - t1698 = prims.convert_element_type(t1697, dtypes.bfloat16) # t1698: "cuda:0 bf16[1, 512, 4096]" - t1699 = prims.linear(t1698, t_transformer_h_11_mlp_fc_1_weight, None) # t1699: "cuda:0 bf16[1, 512, 11008]" - t1700 = prims.linear(t1698, t_transformer_h_11_mlp_fc_2_weight, None) # t1700: "cuda:0 bf16[1, 512, 11008]" - t1701 = prims.convert_element_type(t1699, dtypes.float32) # t1701: "cuda:0 f32[1, 512, 11008]" - t1702 = prims.neg(t1701) # t1702: "cuda:0 f32[1, 512, 11008]" - t1703 = prims.exp(t1702) # t1703: "cuda:0 f32[1, 512, 11008]" - t1704 = ltorch.add(1.0, t1703, alpha=None) # t1704: "cuda:0 f32[1, 512, 11008]" - # t1704 = prims.add(1.0, t1703) # t1704: "cuda:0 f32[1, 512, 11008]" - t1705 = prims.reciprocal(t1704) # t1705: "cuda:0 f32[1, 512, 11008]" - t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: "cuda:0 bf16[1, 512, 11008]" - t1707 = prims.convert_element_type(t1699, dtypes.float32) # t1707: "cuda:0 f32[1, 512, 11008]" - t1708 = prims.convert_element_type(t1706, dtypes.float32) # t1708: "cuda:0 f32[1, 512, 11008]" - t1709 = ltorch.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - # t1709 = prims.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - t1710 = prims.convert_element_type(t1709, dtypes.bfloat16) # t1710: "cuda:0 bf16[1, 512, 11008]" - t1711 = prims.convert_element_type(t1710, dtypes.float32) # t1711: "cuda:0 f32[1, 512, 11008]" - t1712 = prims.convert_element_type(t1700, dtypes.float32) # t1712: "cuda:0 f32[1, 512, 11008]" - t1713 = ltorch.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - # t1713 = prims.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - t1714 = prims.convert_element_type(t1713, dtypes.bfloat16) # t1714: "cuda:0 bf16[1, 512, 11008]" - t1715 = prims.linear(t1714, t_transformer_h_11_mlp_proj_weight, None) # t1715: "cuda:0 bf16[1, 512, 4096]" - t1716 = prims.convert_element_type(t1715, dtypes.float32) # t1716: "cuda:0 f32[1, 512, 4096]" - t1717 = prims.convert_element_type(t1680, dtypes.float32) # t1717: "cuda:0 f32[1, 512, 4096]" - t1718 = ltorch.add(t1716, t1717, alpha=None) # t1718: "cuda:0 f32[1, 512, 4096]" - # t1718 = prims.add(t1716, t1717) # t1718: "cuda:0 f32[1, 512, 4096]" - t1719 = prims.convert_element_type(t1718, dtypes.bfloat16) # t1719: "cuda:0 bf16[1, 512, 4096]" - t1720 = prims.convert_element_type(t1719, dtypes.float32) # t1720: "cuda:0 f32[1, 512, 4096]" - t1721 = ltorch.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - # t1721 = prims.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - t1723 = prims.sum(t1721, (2,)) # t1723: "cuda:0 f32[1, 512]" - t1724 = prims.broadcast_in_dim(t1723, [1, 512, 1], [0, 1]) # t1724: "cuda:0 f32[1, 512, 1]" - t1726 = ltorch.true_divide(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - # t1726 = prims.div(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - t1728 = ltorch.add(t1726, 1e-05, alpha=None) # t1728: "cuda:0 f32[1, 512, 1]" - # t1728 = prims.add(t1726, 1e-05) # t1728: "cuda:0 f32[1, 512, 1]" - t1729 = prims.rsqrt(t1728) # t1729: "cuda:0 f32[1, 512, 1]" - t1730 = prims.broadcast_in_dim(t1729, (1, 512, 4096), (0, 1, 2)) # t1730: "cuda:0 f32[1, 512, 4096]" - t1731 = ltorch.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - # t1731 = prims.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - t1732 = prims.convert_element_type(t1731, dtypes.bfloat16) # t1732: "cuda:0 bf16[1, 512, 4096]" - t1733 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_weight, (1, 512, 4096), (2,)) # t1733: "cuda:0 bf16[1, 512, 4096]" - t1734 = prims.convert_element_type(t1732, dtypes.float32) # t1734: "cuda:0 f32[1, 512, 4096]" - t1735 = prims.convert_element_type(t1733, dtypes.float32) # t1735: "cuda:0 f32[1, 512, 4096]" - t1736 = ltorch.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - # t1736 = prims.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: "cuda:0 bf16[1, 512, 4096]" - t1738 = prims.linear(t1737, t_transformer_h_12_attn_attn_weight, None) # t1738: "cuda:0 bf16[1, 512, 12288]" - t1744 = prims.reshape(t1738, (1, 512, 32, 3, 128)) # t1744: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1750 = prims.transpose(t1744, (0, 2, 3, 1, 4)) # t1750: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1751, t1752, t1753) = ltorch.split(t1750, (1, 1, 1), 2) - # t1751 = prims.slice_prim(t1750, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1751: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1752 = prims.slice_prim(t1750, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1752: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1753 = prims.slice_prim(t1750, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1753: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1759 = prims.reshape(t1751, (1, 32, 512, 128)) # t1759: "cuda:0 bf16[1, 32, 512, 128]" - t1765 = prims.reshape(t1752, (1, 32, 512, 128)) # t1765: "cuda:0 bf16[1, 32, 512, 128]" - t1771 = prims.reshape(t1753, (1, 32, 512, 128)) # t1771: "cuda:0 bf16[1, 32, 512, 128]" - t1772 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1772: "cuda:0 bf16[1, 32, 512, 128]" - t1773 = prims.slice_prim(t1772, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1773: "cuda:0 bf16[1, 32, 512, 64]" - t1774 = prims.slice_prim(t1772, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1774: "cuda:0 bf16[1, 32, 512, 64]" - t1775 = prims.convert_element_type(t1774, dtypes.float32) # t1775: "cuda:0 f32[1, 32, 512, 64]" - t1776 = prims.neg(t1775) # t1776: "cuda:0 f32[1, 32, 512, 64]" - t1777 = prims.convert_element_type(t1776, dtypes.bfloat16) # t1777: "cuda:0 bf16[1, 32, 512, 64]" - t1779 = prims.cat((t1777, t1773), -1) # t1779: "cuda:0 bf16[1, 32, 512, 128]" - t1780 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1780: "cuda:0 f32[1, 32, 512, 128]" - t1781 = prims.convert_element_type(t1772, dtypes.float32) # t1781: "cuda:0 f32[1, 32, 512, 128]" - t1782 = ltorch.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - # t1782 = prims.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - t1783 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1783: "cuda:0 f32[1, 32, 512, 128]" - t1784 = prims.convert_element_type(t1779, dtypes.float32) # t1784: "cuda:0 f32[1, 32, 512, 128]" - t1785 = ltorch.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - # t1785 = prims.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - t1786 = ltorch.add(t1782, t1785, alpha=None) # t1786: "cuda:0 f32[1, 32, 512, 128]" - # t1786 = prims.add(t1782, t1785) # t1786: "cuda:0 f32[1, 32, 512, 128]" - t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: "cuda:0 bf16[1, 32, 512, 128]" - t1788 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1788: "cuda:0 bf16[1, 32, 512, 128]" - t1789 = prims.slice_prim(t1788, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1789: "cuda:0 bf16[1, 32, 512, 64]" - t1790 = prims.slice_prim(t1788, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1790: "cuda:0 bf16[1, 32, 512, 64]" - t1791 = prims.convert_element_type(t1790, dtypes.float32) # t1791: "cuda:0 f32[1, 32, 512, 64]" - t1792 = prims.neg(t1791) # t1792: "cuda:0 f32[1, 32, 512, 64]" - t1793 = prims.convert_element_type(t1792, dtypes.bfloat16) # t1793: "cuda:0 bf16[1, 32, 512, 64]" - t1795 = prims.cat((t1793, t1789), -1) # t1795: "cuda:0 bf16[1, 32, 512, 128]" - t1796 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1796: "cuda:0 f32[1, 32, 512, 128]" - t1797 = prims.convert_element_type(t1788, dtypes.float32) # t1797: "cuda:0 f32[1, 32, 512, 128]" - t1798 = ltorch.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - # t1798 = prims.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - t1799 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1799: "cuda:0 f32[1, 32, 512, 128]" - t1800 = prims.convert_element_type(t1795, dtypes.float32) # t1800: "cuda:0 f32[1, 32, 512, 128]" - t1801 = ltorch.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - # t1801 = prims.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - t1802 = ltorch.add(t1798, t1801, alpha=None) # t1802: "cuda:0 f32[1, 32, 512, 128]" - # t1802 = prims.add(t1798, t1801) # t1802: "cuda:0 f32[1, 32, 512, 128]" - t1803 = prims.convert_element_type(t1802, dtypes.bfloat16) # t1803: "cuda:0 bf16[1, 32, 512, 128]" - t1804 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1804: "cuda:0 bf16[1, 32, 512, 0]" - t1806 = prims.cat((t1787, t1804), -1) # t1806: "cuda:0 bf16[1, 32, 512, 128]" - t1807 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1807: "cuda:0 bf16[1, 32, 512, 0]" - t1809 = prims.cat((t1803, t1807), -1) # t1809: "cuda:0 bf16[1, 32, 512, 128]" - (t1810, t1811, t1812, t1813) = cudnn_sdpa_fwd(t1806, t1809, t1771, None, 0.0, True, scale=0.08838834764831843) - t1816 = prims.transpose(t1810, (0, 2, 1, 3)) # t1816: "cuda:0 bf16[1, 512, 32, 128]" - t1820 = prims.reshape(t1816, (1, 512, 4096)) # t1820: "cuda:0 bf16[1, 512, 4096]" - t1821 = prims.linear(t1820, t_transformer_h_12_attn_proj_weight, None) # t1821: "cuda:0 bf16[1, 512, 4096]" - t1822 = prims.convert_element_type(t1821, dtypes.float32) # t1822: "cuda:0 f32[1, 512, 4096]" - t1823 = prims.convert_element_type(t1719, dtypes.float32) # t1823: "cuda:0 f32[1, 512, 4096]" - t1824 = ltorch.add(t1822, t1823, alpha=None) # t1824: "cuda:0 f32[1, 512, 4096]" - # t1824 = prims.add(t1822, t1823) # t1824: "cuda:0 f32[1, 512, 4096]" - t1825 = prims.convert_element_type(t1824, dtypes.bfloat16) # t1825: "cuda:0 bf16[1, 512, 4096]" - t1826 = prims.convert_element_type(t1825, dtypes.float32) # t1826: "cuda:0 f32[1, 512, 4096]" - t1827 = ltorch.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - # t1827 = prims.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - t1829 = prims.sum(t1827, (2,)) # t1829: "cuda:0 f32[1, 512]" - t1830 = prims.broadcast_in_dim(t1829, [1, 512, 1], [0, 1]) # t1830: "cuda:0 f32[1, 512, 1]" - t1832 = ltorch.true_divide(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - # t1832 = prims.div(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - t1834 = ltorch.add(t1832, 1e-05, alpha=None) # t1834: "cuda:0 f32[1, 512, 1]" - # t1834 = prims.add(t1832, 1e-05) # t1834: "cuda:0 f32[1, 512, 1]" - t1835 = prims.rsqrt(t1834) # t1835: "cuda:0 f32[1, 512, 1]" - t1836 = prims.broadcast_in_dim(t1835, (1, 512, 4096), (0, 1, 2)) # t1836: "cuda:0 f32[1, 512, 4096]" - t1837 = ltorch.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - # t1837 = prims.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - t1838 = prims.convert_element_type(t1837, dtypes.bfloat16) # t1838: "cuda:0 bf16[1, 512, 4096]" - t1839 = prims.broadcast_in_dim(t_transformer_h_12_norm_2_weight, (1, 512, 4096), (2,)) # t1839: "cuda:0 bf16[1, 512, 4096]" - t1840 = prims.convert_element_type(t1838, dtypes.float32) # t1840: "cuda:0 f32[1, 512, 4096]" - t1841 = prims.convert_element_type(t1839, dtypes.float32) # t1841: "cuda:0 f32[1, 512, 4096]" - t1842 = ltorch.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - # t1842 = prims.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - t1843 = prims.convert_element_type(t1842, dtypes.bfloat16) # t1843: "cuda:0 bf16[1, 512, 4096]" - t1844 = prims.linear(t1843, t_transformer_h_12_mlp_fc_1_weight, None) # t1844: "cuda:0 bf16[1, 512, 11008]" - t1845 = prims.linear(t1843, t_transformer_h_12_mlp_fc_2_weight, None) # t1845: "cuda:0 bf16[1, 512, 11008]" - t1846 = prims.convert_element_type(t1844, dtypes.float32) # t1846: "cuda:0 f32[1, 512, 11008]" - t1847 = prims.neg(t1846) # t1847: "cuda:0 f32[1, 512, 11008]" - t1848 = prims.exp(t1847) # t1848: "cuda:0 f32[1, 512, 11008]" - t1849 = ltorch.add(1.0, t1848, alpha=None) # t1849: "cuda:0 f32[1, 512, 11008]" - # t1849 = prims.add(1.0, t1848) # t1849: "cuda:0 f32[1, 512, 11008]" - t1850 = prims.reciprocal(t1849) # t1850: "cuda:0 f32[1, 512, 11008]" - t1851 = prims.convert_element_type(t1850, dtypes.bfloat16) # t1851: "cuda:0 bf16[1, 512, 11008]" - t1852 = prims.convert_element_type(t1844, dtypes.float32) # t1852: "cuda:0 f32[1, 512, 11008]" - t1853 = prims.convert_element_type(t1851, dtypes.float32) # t1853: "cuda:0 f32[1, 512, 11008]" - t1854 = ltorch.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - # t1854 = prims.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - t1855 = prims.convert_element_type(t1854, dtypes.bfloat16) # t1855: "cuda:0 bf16[1, 512, 11008]" - t1856 = prims.convert_element_type(t1855, dtypes.float32) # t1856: "cuda:0 f32[1, 512, 11008]" - t1857 = prims.convert_element_type(t1845, dtypes.float32) # t1857: "cuda:0 f32[1, 512, 11008]" - t1858 = ltorch.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - # t1858 = prims.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - t1859 = prims.convert_element_type(t1858, dtypes.bfloat16) # t1859: "cuda:0 bf16[1, 512, 11008]" - t1860 = prims.linear(t1859, t_transformer_h_12_mlp_proj_weight, None) # t1860: "cuda:0 bf16[1, 512, 4096]" - t1861 = prims.convert_element_type(t1860, dtypes.float32) # t1861: "cuda:0 f32[1, 512, 4096]" - t1862 = prims.convert_element_type(t1825, dtypes.float32) # t1862: "cuda:0 f32[1, 512, 4096]" - t1863 = ltorch.add(t1861, t1862, alpha=None) # t1863: "cuda:0 f32[1, 512, 4096]" - # t1863 = prims.add(t1861, t1862) # t1863: "cuda:0 f32[1, 512, 4096]" - t1864 = prims.convert_element_type(t1863, dtypes.bfloat16) # t1864: "cuda:0 bf16[1, 512, 4096]" - t1865 = prims.convert_element_type(t1864, dtypes.float32) # t1865: "cuda:0 f32[1, 512, 4096]" - t1866 = ltorch.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - # t1866 = prims.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - t1868 = prims.sum(t1866, (2,)) # t1868: "cuda:0 f32[1, 512]" - t1869 = prims.broadcast_in_dim(t1868, [1, 512, 1], [0, 1]) # t1869: "cuda:0 f32[1, 512, 1]" - t1871 = ltorch.true_divide(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - # t1871 = prims.div(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - t1873 = ltorch.add(t1871, 1e-05, alpha=None) # t1873: "cuda:0 f32[1, 512, 1]" - # t1873 = prims.add(t1871, 1e-05) # t1873: "cuda:0 f32[1, 512, 1]" - t1874 = prims.rsqrt(t1873) # t1874: "cuda:0 f32[1, 512, 1]" - t1875 = prims.broadcast_in_dim(t1874, (1, 512, 4096), (0, 1, 2)) # t1875: "cuda:0 f32[1, 512, 4096]" - t1876 = ltorch.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - # t1876 = prims.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - t1877 = prims.convert_element_type(t1876, dtypes.bfloat16) # t1877: "cuda:0 bf16[1, 512, 4096]" - t1878 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_weight, (1, 512, 4096), (2,)) # t1878: "cuda:0 bf16[1, 512, 4096]" - t1879 = prims.convert_element_type(t1877, dtypes.float32) # t1879: "cuda:0 f32[1, 512, 4096]" - t1880 = prims.convert_element_type(t1878, dtypes.float32) # t1880: "cuda:0 f32[1, 512, 4096]" - t1881 = ltorch.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - # t1881 = prims.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - t1882 = prims.convert_element_type(t1881, dtypes.bfloat16) # t1882: "cuda:0 bf16[1, 512, 4096]" - t1883 = prims.linear(t1882, t_transformer_h_13_attn_attn_weight, None) # t1883: "cuda:0 bf16[1, 512, 12288]" - t1889 = prims.reshape(t1883, (1, 512, 32, 3, 128)) # t1889: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1895 = prims.transpose(t1889, (0, 2, 3, 1, 4)) # t1895: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1896, t1897, t1898) = ltorch.split(t1895, (1, 1, 1), 2) - # t1896 = prims.slice_prim(t1895, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1896: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1897 = prims.slice_prim(t1895, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1897: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1898 = prims.slice_prim(t1895, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1898: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1904 = prims.reshape(t1896, (1, 32, 512, 128)) # t1904: "cuda:0 bf16[1, 32, 512, 128]" - t1910 = prims.reshape(t1897, (1, 32, 512, 128)) # t1910: "cuda:0 bf16[1, 32, 512, 128]" - t1916 = prims.reshape(t1898, (1, 32, 512, 128)) # t1916: "cuda:0 bf16[1, 32, 512, 128]" - t1917 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - t1918 = prims.slice_prim(t1917, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1918: "cuda:0 bf16[1, 32, 512, 64]" - t1919 = prims.slice_prim(t1917, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1919: "cuda:0 bf16[1, 32, 512, 64]" - t1920 = prims.convert_element_type(t1919, dtypes.float32) # t1920: "cuda:0 f32[1, 32, 512, 64]" - t1921 = prims.neg(t1920) # t1921: "cuda:0 f32[1, 32, 512, 64]" - t1922 = prims.convert_element_type(t1921, dtypes.bfloat16) # t1922: "cuda:0 bf16[1, 32, 512, 64]" - t1924 = prims.cat((t1922, t1918), -1) # t1924: "cuda:0 bf16[1, 32, 512, 128]" - t1925 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1925: "cuda:0 f32[1, 32, 512, 128]" - t1926 = prims.convert_element_type(t1917, dtypes.float32) # t1926: "cuda:0 f32[1, 32, 512, 128]" - t1927 = ltorch.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - # t1927 = prims.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - t1928 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1928: "cuda:0 f32[1, 32, 512, 128]" - t1929 = prims.convert_element_type(t1924, dtypes.float32) # t1929: "cuda:0 f32[1, 32, 512, 128]" - t1930 = ltorch.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - # t1930 = prims.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - t1931 = ltorch.add(t1927, t1930, alpha=None) # t1931: "cuda:0 f32[1, 32, 512, 128]" - # t1931 = prims.add(t1927, t1930) # t1931: "cuda:0 f32[1, 32, 512, 128]" - t1932 = prims.convert_element_type(t1931, dtypes.bfloat16) # t1932: "cuda:0 bf16[1, 32, 512, 128]" - t1933 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1933: "cuda:0 bf16[1, 32, 512, 128]" - t1934 = prims.slice_prim(t1933, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1934: "cuda:0 bf16[1, 32, 512, 64]" - t1935 = prims.slice_prim(t1933, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1935: "cuda:0 bf16[1, 32, 512, 64]" - t1936 = prims.convert_element_type(t1935, dtypes.float32) # t1936: "cuda:0 f32[1, 32, 512, 64]" - t1937 = prims.neg(t1936) # t1937: "cuda:0 f32[1, 32, 512, 64]" - t1938 = prims.convert_element_type(t1937, dtypes.bfloat16) # t1938: "cuda:0 bf16[1, 32, 512, 64]" - t1940 = prims.cat((t1938, t1934), -1) # t1940: "cuda:0 bf16[1, 32, 512, 128]" - t1941 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1941: "cuda:0 f32[1, 32, 512, 128]" - t1942 = prims.convert_element_type(t1933, dtypes.float32) # t1942: "cuda:0 f32[1, 32, 512, 128]" - t1943 = ltorch.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - # t1943 = prims.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - t1944 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1944: "cuda:0 f32[1, 32, 512, 128]" - t1945 = prims.convert_element_type(t1940, dtypes.float32) # t1945: "cuda:0 f32[1, 32, 512, 128]" - t1946 = ltorch.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - # t1946 = prims.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - t1947 = ltorch.add(t1943, t1946, alpha=None) # t1947: "cuda:0 f32[1, 32, 512, 128]" - # t1947 = prims.add(t1943, t1946) # t1947: "cuda:0 f32[1, 32, 512, 128]" - t1948 = prims.convert_element_type(t1947, dtypes.bfloat16) # t1948: "cuda:0 bf16[1, 32, 512, 128]" - t1949 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1949: "cuda:0 bf16[1, 32, 512, 0]" - t1951 = prims.cat((t1932, t1949), -1) # t1951: "cuda:0 bf16[1, 32, 512, 128]" - t1952 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1952: "cuda:0 bf16[1, 32, 512, 0]" - t1954 = prims.cat((t1948, t1952), -1) # t1954: "cuda:0 bf16[1, 32, 512, 128]" - (t1955, t1956, t1957, t1958) = cudnn_sdpa_fwd(t1951, t1954, t1916, None, 0.0, True, scale=0.08838834764831843) - t1961 = prims.transpose(t1955, (0, 2, 1, 3)) # t1961: "cuda:0 bf16[1, 512, 32, 128]" - t1965 = prims.reshape(t1961, (1, 512, 4096)) # t1965: "cuda:0 bf16[1, 512, 4096]" - t1966 = prims.linear(t1965, t_transformer_h_13_attn_proj_weight, None) # t1966: "cuda:0 bf16[1, 512, 4096]" - t1967 = prims.convert_element_type(t1966, dtypes.float32) # t1967: "cuda:0 f32[1, 512, 4096]" - t1968 = prims.convert_element_type(t1864, dtypes.float32) # t1968: "cuda:0 f32[1, 512, 4096]" - t1969 = ltorch.add(t1967, t1968, alpha=None) # t1969: "cuda:0 f32[1, 512, 4096]" - # t1969 = prims.add(t1967, t1968) # t1969: "cuda:0 f32[1, 512, 4096]" - t1970 = prims.convert_element_type(t1969, dtypes.bfloat16) # t1970: "cuda:0 bf16[1, 512, 4096]" - t1971 = prims.convert_element_type(t1970, dtypes.float32) # t1971: "cuda:0 f32[1, 512, 4096]" - t1972 = ltorch.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - # t1972 = prims.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - t1974 = prims.sum(t1972, (2,)) # t1974: "cuda:0 f32[1, 512]" - t1975 = prims.broadcast_in_dim(t1974, [1, 512, 1], [0, 1]) # t1975: "cuda:0 f32[1, 512, 1]" - t1977 = ltorch.true_divide(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - # t1977 = prims.div(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - t1979 = ltorch.add(t1977, 1e-05, alpha=None) # t1979: "cuda:0 f32[1, 512, 1]" - # t1979 = prims.add(t1977, 1e-05) # t1979: "cuda:0 f32[1, 512, 1]" - t1980 = prims.rsqrt(t1979) # t1980: "cuda:0 f32[1, 512, 1]" - t1981 = prims.broadcast_in_dim(t1980, (1, 512, 4096), (0, 1, 2)) # t1981: "cuda:0 f32[1, 512, 4096]" - t1982 = ltorch.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - # t1982 = prims.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - t1983 = prims.convert_element_type(t1982, dtypes.bfloat16) # t1983: "cuda:0 bf16[1, 512, 4096]" - t1984 = prims.broadcast_in_dim(t_transformer_h_13_norm_2_weight, (1, 512, 4096), (2,)) # t1984: "cuda:0 bf16[1, 512, 4096]" - t1985 = prims.convert_element_type(t1983, dtypes.float32) # t1985: "cuda:0 f32[1, 512, 4096]" - t1986 = prims.convert_element_type(t1984, dtypes.float32) # t1986: "cuda:0 f32[1, 512, 4096]" - t1987 = ltorch.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - # t1987 = prims.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - t1988 = prims.convert_element_type(t1987, dtypes.bfloat16) # t1988: "cuda:0 bf16[1, 512, 4096]" - t1989 = prims.linear(t1988, t_transformer_h_13_mlp_fc_1_weight, None) # t1989: "cuda:0 bf16[1, 512, 11008]" - t1990 = prims.linear(t1988, t_transformer_h_13_mlp_fc_2_weight, None) # t1990: "cuda:0 bf16[1, 512, 11008]" - t1991 = prims.convert_element_type(t1989, dtypes.float32) # t1991: "cuda:0 f32[1, 512, 11008]" - t1992 = prims.neg(t1991) # t1992: "cuda:0 f32[1, 512, 11008]" - t1993 = prims.exp(t1992) # t1993: "cuda:0 f32[1, 512, 11008]" - t1994 = ltorch.add(1.0, t1993, alpha=None) # t1994: "cuda:0 f32[1, 512, 11008]" - # t1994 = prims.add(1.0, t1993) # t1994: "cuda:0 f32[1, 512, 11008]" - t1995 = prims.reciprocal(t1994) # t1995: "cuda:0 f32[1, 512, 11008]" - t1996 = prims.convert_element_type(t1995, dtypes.bfloat16) # t1996: "cuda:0 bf16[1, 512, 11008]" - t1997 = prims.convert_element_type(t1989, dtypes.float32) # t1997: "cuda:0 f32[1, 512, 11008]" - t1998 = prims.convert_element_type(t1996, dtypes.float32) # t1998: "cuda:0 f32[1, 512, 11008]" - t1999 = ltorch.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - # t1999 = prims.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - t2000 = prims.convert_element_type(t1999, dtypes.bfloat16) # t2000: "cuda:0 bf16[1, 512, 11008]" - t2001 = prims.convert_element_type(t2000, dtypes.float32) # t2001: "cuda:0 f32[1, 512, 11008]" - t2002 = prims.convert_element_type(t1990, dtypes.float32) # t2002: "cuda:0 f32[1, 512, 11008]" - t2003 = ltorch.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - # t2003 = prims.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - t2004 = prims.convert_element_type(t2003, dtypes.bfloat16) # t2004: "cuda:0 bf16[1, 512, 11008]" - t2005 = prims.linear(t2004, t_transformer_h_13_mlp_proj_weight, None) # t2005: "cuda:0 bf16[1, 512, 4096]" - t2006 = prims.convert_element_type(t2005, dtypes.float32) # t2006: "cuda:0 f32[1, 512, 4096]" - t2007 = prims.convert_element_type(t1970, dtypes.float32) # t2007: "cuda:0 f32[1, 512, 4096]" - t2008 = ltorch.add(t2006, t2007, alpha=None) # t2008: "cuda:0 f32[1, 512, 4096]" - # t2008 = prims.add(t2006, t2007) # t2008: "cuda:0 f32[1, 512, 4096]" - t2009 = prims.convert_element_type(t2008, dtypes.bfloat16) # t2009: "cuda:0 bf16[1, 512, 4096]" - t2010 = prims.convert_element_type(t2009, dtypes.float32) # t2010: "cuda:0 f32[1, 512, 4096]" - t2011 = ltorch.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - # t2011 = prims.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - t2013 = prims.sum(t2011, (2,)) # t2013: "cuda:0 f32[1, 512]" - t2014 = prims.broadcast_in_dim(t2013, [1, 512, 1], [0, 1]) # t2014: "cuda:0 f32[1, 512, 1]" - t2016 = ltorch.true_divide(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - # t2016 = prims.div(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - t2018 = ltorch.add(t2016, 1e-05, alpha=None) # t2018: "cuda:0 f32[1, 512, 1]" - # t2018 = prims.add(t2016, 1e-05) # t2018: "cuda:0 f32[1, 512, 1]" - t2019 = prims.rsqrt(t2018) # t2019: "cuda:0 f32[1, 512, 1]" - t2020 = prims.broadcast_in_dim(t2019, (1, 512, 4096), (0, 1, 2)) # t2020: "cuda:0 f32[1, 512, 4096]" - t2021 = ltorch.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - # t2021 = prims.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - t2022 = prims.convert_element_type(t2021, dtypes.bfloat16) # t2022: "cuda:0 bf16[1, 512, 4096]" - t2023 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_weight, (1, 512, 4096), (2,)) # t2023: "cuda:0 bf16[1, 512, 4096]" - t2024 = prims.convert_element_type(t2022, dtypes.float32) # t2024: "cuda:0 f32[1, 512, 4096]" - t2025 = prims.convert_element_type(t2023, dtypes.float32) # t2025: "cuda:0 f32[1, 512, 4096]" - t2026 = ltorch.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - # t2026 = prims.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - t2027 = prims.convert_element_type(t2026, dtypes.bfloat16) # t2027: "cuda:0 bf16[1, 512, 4096]" - t2028 = prims.linear(t2027, t_transformer_h_14_attn_attn_weight, None) # t2028: "cuda:0 bf16[1, 512, 12288]" - t2034 = prims.reshape(t2028, (1, 512, 32, 3, 128)) # t2034: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2040 = prims.transpose(t2034, (0, 2, 3, 1, 4)) # t2040: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2041, t2042, t2043) = ltorch.split(t2040, (1, 1, 1), 2) - # t2041 = prims.slice_prim(t2040, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2041: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2042 = prims.slice_prim(t2040, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2042: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2043 = prims.slice_prim(t2040, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2043: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2049 = prims.reshape(t2041, (1, 32, 512, 128)) # t2049: "cuda:0 bf16[1, 32, 512, 128]" - t2055 = prims.reshape(t2042, (1, 32, 512, 128)) # t2055: "cuda:0 bf16[1, 32, 512, 128]" - t2061 = prims.reshape(t2043, (1, 32, 512, 128)) # t2061: "cuda:0 bf16[1, 32, 512, 128]" - t2062 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2062: "cuda:0 bf16[1, 32, 512, 128]" - t2063 = prims.slice_prim(t2062, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2063: "cuda:0 bf16[1, 32, 512, 64]" - t2064 = prims.slice_prim(t2062, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2064: "cuda:0 bf16[1, 32, 512, 64]" - t2065 = prims.convert_element_type(t2064, dtypes.float32) # t2065: "cuda:0 f32[1, 32, 512, 64]" - t2066 = prims.neg(t2065) # t2066: "cuda:0 f32[1, 32, 512, 64]" - t2067 = prims.convert_element_type(t2066, dtypes.bfloat16) # t2067: "cuda:0 bf16[1, 32, 512, 64]" - t2069 = prims.cat((t2067, t2063), -1) # t2069: "cuda:0 bf16[1, 32, 512, 128]" - t2070 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2070: "cuda:0 f32[1, 32, 512, 128]" - t2071 = prims.convert_element_type(t2062, dtypes.float32) # t2071: "cuda:0 f32[1, 32, 512, 128]" - t2072 = ltorch.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - # t2072 = prims.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - t2073 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2073: "cuda:0 f32[1, 32, 512, 128]" - t2074 = prims.convert_element_type(t2069, dtypes.float32) # t2074: "cuda:0 f32[1, 32, 512, 128]" - t2075 = ltorch.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - # t2075 = prims.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - t2076 = ltorch.add(t2072, t2075, alpha=None) # t2076: "cuda:0 f32[1, 32, 512, 128]" - # t2076 = prims.add(t2072, t2075) # t2076: "cuda:0 f32[1, 32, 512, 128]" - t2077 = prims.convert_element_type(t2076, dtypes.bfloat16) # t2077: "cuda:0 bf16[1, 32, 512, 128]" - t2078 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2078: "cuda:0 bf16[1, 32, 512, 128]" - t2079 = prims.slice_prim(t2078, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2079: "cuda:0 bf16[1, 32, 512, 64]" - t2080 = prims.slice_prim(t2078, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2080: "cuda:0 bf16[1, 32, 512, 64]" - t2081 = prims.convert_element_type(t2080, dtypes.float32) # t2081: "cuda:0 f32[1, 32, 512, 64]" - t2082 = prims.neg(t2081) # t2082: "cuda:0 f32[1, 32, 512, 64]" - t2083 = prims.convert_element_type(t2082, dtypes.bfloat16) # t2083: "cuda:0 bf16[1, 32, 512, 64]" - t2085 = prims.cat((t2083, t2079), -1) # t2085: "cuda:0 bf16[1, 32, 512, 128]" - t2086 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2086: "cuda:0 f32[1, 32, 512, 128]" - t2087 = prims.convert_element_type(t2078, dtypes.float32) # t2087: "cuda:0 f32[1, 32, 512, 128]" - t2088 = ltorch.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - # t2088 = prims.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - t2089 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2089: "cuda:0 f32[1, 32, 512, 128]" - t2090 = prims.convert_element_type(t2085, dtypes.float32) # t2090: "cuda:0 f32[1, 32, 512, 128]" - t2091 = ltorch.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - # t2091 = prims.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - t2092 = ltorch.add(t2088, t2091, alpha=None) # t2092: "cuda:0 f32[1, 32, 512, 128]" - # t2092 = prims.add(t2088, t2091) # t2092: "cuda:0 f32[1, 32, 512, 128]" - t2093 = prims.convert_element_type(t2092, dtypes.bfloat16) # t2093: "cuda:0 bf16[1, 32, 512, 128]" - t2094 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2094: "cuda:0 bf16[1, 32, 512, 0]" - t2096 = prims.cat((t2077, t2094), -1) # t2096: "cuda:0 bf16[1, 32, 512, 128]" - t2097 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2097: "cuda:0 bf16[1, 32, 512, 0]" - t2099 = prims.cat((t2093, t2097), -1) # t2099: "cuda:0 bf16[1, 32, 512, 128]" - (t2100, t2101, t2102, t2103) = cudnn_sdpa_fwd(t2096, t2099, t2061, None, 0.0, True, scale=0.08838834764831843) - t2106 = prims.transpose(t2100, (0, 2, 1, 3)) # t2106: "cuda:0 bf16[1, 512, 32, 128]" - t2110 = prims.reshape(t2106, (1, 512, 4096)) # t2110: "cuda:0 bf16[1, 512, 4096]" - t2111 = prims.linear(t2110, t_transformer_h_14_attn_proj_weight, None) # t2111: "cuda:0 bf16[1, 512, 4096]" - t2112 = prims.convert_element_type(t2111, dtypes.float32) # t2112: "cuda:0 f32[1, 512, 4096]" - t2113 = prims.convert_element_type(t2009, dtypes.float32) # t2113: "cuda:0 f32[1, 512, 4096]" - t2114 = ltorch.add(t2112, t2113, alpha=None) # t2114: "cuda:0 f32[1, 512, 4096]" - # t2114 = prims.add(t2112, t2113) # t2114: "cuda:0 f32[1, 512, 4096]" - t2115 = prims.convert_element_type(t2114, dtypes.bfloat16) # t2115: "cuda:0 bf16[1, 512, 4096]" - t2116 = prims.convert_element_type(t2115, dtypes.float32) # t2116: "cuda:0 f32[1, 512, 4096]" - t2117 = ltorch.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - # t2117 = prims.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - t2119 = prims.sum(t2117, (2,)) # t2119: "cuda:0 f32[1, 512]" - t2120 = prims.broadcast_in_dim(t2119, [1, 512, 1], [0, 1]) # t2120: "cuda:0 f32[1, 512, 1]" - t2122 = ltorch.true_divide(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - # t2122 = prims.div(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - t2124 = ltorch.add(t2122, 1e-05, alpha=None) # t2124: "cuda:0 f32[1, 512, 1]" - # t2124 = prims.add(t2122, 1e-05) # t2124: "cuda:0 f32[1, 512, 1]" - t2125 = prims.rsqrt(t2124) # t2125: "cuda:0 f32[1, 512, 1]" - t2126 = prims.broadcast_in_dim(t2125, (1, 512, 4096), (0, 1, 2)) # t2126: "cuda:0 f32[1, 512, 4096]" - t2127 = ltorch.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - # t2127 = prims.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - t2128 = prims.convert_element_type(t2127, dtypes.bfloat16) # t2128: "cuda:0 bf16[1, 512, 4096]" - t2129 = prims.broadcast_in_dim(t_transformer_h_14_norm_2_weight, (1, 512, 4096), (2,)) # t2129: "cuda:0 bf16[1, 512, 4096]" - t2130 = prims.convert_element_type(t2128, dtypes.float32) # t2130: "cuda:0 f32[1, 512, 4096]" - t2131 = prims.convert_element_type(t2129, dtypes.float32) # t2131: "cuda:0 f32[1, 512, 4096]" - t2132 = ltorch.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - # t2132 = prims.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - t2133 = prims.convert_element_type(t2132, dtypes.bfloat16) # t2133: "cuda:0 bf16[1, 512, 4096]" - t2134 = prims.linear(t2133, t_transformer_h_14_mlp_fc_1_weight, None) # t2134: "cuda:0 bf16[1, 512, 11008]" - t2135 = prims.linear(t2133, t_transformer_h_14_mlp_fc_2_weight, None) # t2135: "cuda:0 bf16[1, 512, 11008]" - t2136 = prims.convert_element_type(t2134, dtypes.float32) # t2136: "cuda:0 f32[1, 512, 11008]" - t2137 = prims.neg(t2136) # t2137: "cuda:0 f32[1, 512, 11008]" - t2138 = prims.exp(t2137) # t2138: "cuda:0 f32[1, 512, 11008]" - t2139 = ltorch.add(1.0, t2138, alpha=None) # t2139: "cuda:0 f32[1, 512, 11008]" - # t2139 = prims.add(1.0, t2138) # t2139: "cuda:0 f32[1, 512, 11008]" - t2140 = prims.reciprocal(t2139) # t2140: "cuda:0 f32[1, 512, 11008]" - t2141 = prims.convert_element_type(t2140, dtypes.bfloat16) # t2141: "cuda:0 bf16[1, 512, 11008]" - t2142 = prims.convert_element_type(t2134, dtypes.float32) # t2142: "cuda:0 f32[1, 512, 11008]" - t2143 = prims.convert_element_type(t2141, dtypes.float32) # t2143: "cuda:0 f32[1, 512, 11008]" - t2144 = ltorch.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - # t2144 = prims.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - t2145 = prims.convert_element_type(t2144, dtypes.bfloat16) # t2145: "cuda:0 bf16[1, 512, 11008]" - t2146 = prims.convert_element_type(t2145, dtypes.float32) # t2146: "cuda:0 f32[1, 512, 11008]" - t2147 = prims.convert_element_type(t2135, dtypes.float32) # t2147: "cuda:0 f32[1, 512, 11008]" - t2148 = ltorch.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - # t2148 = prims.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - t2149 = prims.convert_element_type(t2148, dtypes.bfloat16) # t2149: "cuda:0 bf16[1, 512, 11008]" - t2150 = prims.linear(t2149, t_transformer_h_14_mlp_proj_weight, None) # t2150: "cuda:0 bf16[1, 512, 4096]" - t2151 = prims.convert_element_type(t2150, dtypes.float32) # t2151: "cuda:0 f32[1, 512, 4096]" - t2152 = prims.convert_element_type(t2115, dtypes.float32) # t2152: "cuda:0 f32[1, 512, 4096]" - t2153 = ltorch.add(t2151, t2152, alpha=None) # t2153: "cuda:0 f32[1, 512, 4096]" - # t2153 = prims.add(t2151, t2152) # t2153: "cuda:0 f32[1, 512, 4096]" - t2154 = prims.convert_element_type(t2153, dtypes.bfloat16) # t2154: "cuda:0 bf16[1, 512, 4096]" - t2155 = prims.convert_element_type(t2154, dtypes.float32) # t2155: "cuda:0 f32[1, 512, 4096]" - t2156 = ltorch.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - # t2156 = prims.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - t2158 = prims.sum(t2156, (2,)) # t2158: "cuda:0 f32[1, 512]" - t2159 = prims.broadcast_in_dim(t2158, [1, 512, 1], [0, 1]) # t2159: "cuda:0 f32[1, 512, 1]" - t2161 = ltorch.true_divide(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - # t2161 = prims.div(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - t2163 = ltorch.add(t2161, 1e-05, alpha=None) # t2163: "cuda:0 f32[1, 512, 1]" - # t2163 = prims.add(t2161, 1e-05) # t2163: "cuda:0 f32[1, 512, 1]" - t2164 = prims.rsqrt(t2163) # t2164: "cuda:0 f32[1, 512, 1]" - t2165 = prims.broadcast_in_dim(t2164, (1, 512, 4096), (0, 1, 2)) # t2165: "cuda:0 f32[1, 512, 4096]" - t2166 = ltorch.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - # t2166 = prims.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - t2167 = prims.convert_element_type(t2166, dtypes.bfloat16) # t2167: "cuda:0 bf16[1, 512, 4096]" - t2168 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_weight, (1, 512, 4096), (2,)) # t2168: "cuda:0 bf16[1, 512, 4096]" - t2169 = prims.convert_element_type(t2167, dtypes.float32) # t2169: "cuda:0 f32[1, 512, 4096]" - t2170 = prims.convert_element_type(t2168, dtypes.float32) # t2170: "cuda:0 f32[1, 512, 4096]" - t2171 = ltorch.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - # t2171 = prims.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - t2172 = prims.convert_element_type(t2171, dtypes.bfloat16) # t2172: "cuda:0 bf16[1, 512, 4096]" - t2173 = prims.linear(t2172, t_transformer_h_15_attn_attn_weight, None) # t2173: "cuda:0 bf16[1, 512, 12288]" - t2179 = prims.reshape(t2173, (1, 512, 32, 3, 128)) # t2179: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2185 = prims.transpose(t2179, (0, 2, 3, 1, 4)) # t2185: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2186, t2187, t2188) = ltorch.split(t2185, (1, 1, 1), 2) - # t2186 = prims.slice_prim(t2185, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2186: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2187 = prims.slice_prim(t2185, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2187: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2188 = prims.slice_prim(t2185, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2188: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2194 = prims.reshape(t2186, (1, 32, 512, 128)) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - t2200 = prims.reshape(t2187, (1, 32, 512, 128)) # t2200: "cuda:0 bf16[1, 32, 512, 128]" - t2206 = prims.reshape(t2188, (1, 32, 512, 128)) # t2206: "cuda:0 bf16[1, 32, 512, 128]" - t2207 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2207: "cuda:0 bf16[1, 32, 512, 128]" - t2208 = prims.slice_prim(t2207, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2208: "cuda:0 bf16[1, 32, 512, 64]" - t2209 = prims.slice_prim(t2207, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2209: "cuda:0 bf16[1, 32, 512, 64]" - t2210 = prims.convert_element_type(t2209, dtypes.float32) # t2210: "cuda:0 f32[1, 32, 512, 64]" - t2211 = prims.neg(t2210) # t2211: "cuda:0 f32[1, 32, 512, 64]" - t2212 = prims.convert_element_type(t2211, dtypes.bfloat16) # t2212: "cuda:0 bf16[1, 32, 512, 64]" - t2214 = prims.cat((t2212, t2208), -1) # t2214: "cuda:0 bf16[1, 32, 512, 128]" - t2215 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2215: "cuda:0 f32[1, 32, 512, 128]" - t2216 = prims.convert_element_type(t2207, dtypes.float32) # t2216: "cuda:0 f32[1, 32, 512, 128]" - t2217 = ltorch.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - # t2217 = prims.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - t2218 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2218: "cuda:0 f32[1, 32, 512, 128]" - t2219 = prims.convert_element_type(t2214, dtypes.float32) # t2219: "cuda:0 f32[1, 32, 512, 128]" - t2220 = ltorch.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - # t2220 = prims.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - t2221 = ltorch.add(t2217, t2220, alpha=None) # t2221: "cuda:0 f32[1, 32, 512, 128]" - # t2221 = prims.add(t2217, t2220) # t2221: "cuda:0 f32[1, 32, 512, 128]" - t2222 = prims.convert_element_type(t2221, dtypes.bfloat16) # t2222: "cuda:0 bf16[1, 32, 512, 128]" - t2223 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2223: "cuda:0 bf16[1, 32, 512, 128]" - t2224 = prims.slice_prim(t2223, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2224: "cuda:0 bf16[1, 32, 512, 64]" - t2225 = prims.slice_prim(t2223, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2225: "cuda:0 bf16[1, 32, 512, 64]" - t2226 = prims.convert_element_type(t2225, dtypes.float32) # t2226: "cuda:0 f32[1, 32, 512, 64]" - t2227 = prims.neg(t2226) # t2227: "cuda:0 f32[1, 32, 512, 64]" - t2228 = prims.convert_element_type(t2227, dtypes.bfloat16) # t2228: "cuda:0 bf16[1, 32, 512, 64]" - t2230 = prims.cat((t2228, t2224), -1) # t2230: "cuda:0 bf16[1, 32, 512, 128]" - t2231 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2231: "cuda:0 f32[1, 32, 512, 128]" - t2232 = prims.convert_element_type(t2223, dtypes.float32) # t2232: "cuda:0 f32[1, 32, 512, 128]" - t2233 = ltorch.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - # t2233 = prims.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - t2234 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2234: "cuda:0 f32[1, 32, 512, 128]" - t2235 = prims.convert_element_type(t2230, dtypes.float32) # t2235: "cuda:0 f32[1, 32, 512, 128]" - t2236 = ltorch.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - # t2236 = prims.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - t2237 = ltorch.add(t2233, t2236, alpha=None) # t2237: "cuda:0 f32[1, 32, 512, 128]" - # t2237 = prims.add(t2233, t2236) # t2237: "cuda:0 f32[1, 32, 512, 128]" - t2238 = prims.convert_element_type(t2237, dtypes.bfloat16) # t2238: "cuda:0 bf16[1, 32, 512, 128]" - t2239 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2239: "cuda:0 bf16[1, 32, 512, 0]" - t2241 = prims.cat((t2222, t2239), -1) # t2241: "cuda:0 bf16[1, 32, 512, 128]" - t2242 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2242: "cuda:0 bf16[1, 32, 512, 0]" - t2244 = prims.cat((t2238, t2242), -1) # t2244: "cuda:0 bf16[1, 32, 512, 128]" - (t2245, t2246, t2247, t2248) = cudnn_sdpa_fwd(t2241, t2244, t2206, None, 0.0, True, scale=0.08838834764831843) - t2251 = prims.transpose(t2245, (0, 2, 1, 3)) # t2251: "cuda:0 bf16[1, 512, 32, 128]" - t2255 = prims.reshape(t2251, (1, 512, 4096)) # t2255: "cuda:0 bf16[1, 512, 4096]" - t2256 = prims.linear(t2255, t_transformer_h_15_attn_proj_weight, None) # t2256: "cuda:0 bf16[1, 512, 4096]" - t2257 = prims.convert_element_type(t2256, dtypes.float32) # t2257: "cuda:0 f32[1, 512, 4096]" - t2258 = prims.convert_element_type(t2154, dtypes.float32) # t2258: "cuda:0 f32[1, 512, 4096]" - t2259 = ltorch.add(t2257, t2258, alpha=None) # t2259: "cuda:0 f32[1, 512, 4096]" - # t2259 = prims.add(t2257, t2258) # t2259: "cuda:0 f32[1, 512, 4096]" - t2260 = prims.convert_element_type(t2259, dtypes.bfloat16) # t2260: "cuda:0 bf16[1, 512, 4096]" - t2261 = prims.convert_element_type(t2260, dtypes.float32) # t2261: "cuda:0 f32[1, 512, 4096]" - t2262 = ltorch.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - # t2262 = prims.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - t2264 = prims.sum(t2262, (2,)) # t2264: "cuda:0 f32[1, 512]" - t2265 = prims.broadcast_in_dim(t2264, [1, 512, 1], [0, 1]) # t2265: "cuda:0 f32[1, 512, 1]" - t2267 = ltorch.true_divide(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - # t2267 = prims.div(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - t2269 = ltorch.add(t2267, 1e-05, alpha=None) # t2269: "cuda:0 f32[1, 512, 1]" - # t2269 = prims.add(t2267, 1e-05) # t2269: "cuda:0 f32[1, 512, 1]" - t2270 = prims.rsqrt(t2269) # t2270: "cuda:0 f32[1, 512, 1]" - t2271 = prims.broadcast_in_dim(t2270, (1, 512, 4096), (0, 1, 2)) # t2271: "cuda:0 f32[1, 512, 4096]" - t2272 = ltorch.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - # t2272 = prims.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - t2273 = prims.convert_element_type(t2272, dtypes.bfloat16) # t2273: "cuda:0 bf16[1, 512, 4096]" - t2274 = prims.broadcast_in_dim(t_transformer_h_15_norm_2_weight, (1, 512, 4096), (2,)) # t2274: "cuda:0 bf16[1, 512, 4096]" - t2275 = prims.convert_element_type(t2273, dtypes.float32) # t2275: "cuda:0 f32[1, 512, 4096]" - t2276 = prims.convert_element_type(t2274, dtypes.float32) # t2276: "cuda:0 f32[1, 512, 4096]" - t2277 = ltorch.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - # t2277 = prims.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - t2278 = prims.convert_element_type(t2277, dtypes.bfloat16) # t2278: "cuda:0 bf16[1, 512, 4096]" - t2279 = prims.linear(t2278, t_transformer_h_15_mlp_fc_1_weight, None) # t2279: "cuda:0 bf16[1, 512, 11008]" - t2280 = prims.linear(t2278, t_transformer_h_15_mlp_fc_2_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - t2281 = prims.convert_element_type(t2279, dtypes.float32) # t2281: "cuda:0 f32[1, 512, 11008]" - t2282 = prims.neg(t2281) # t2282: "cuda:0 f32[1, 512, 11008]" - t2283 = prims.exp(t2282) # t2283: "cuda:0 f32[1, 512, 11008]" - t2284 = ltorch.add(1.0, t2283, alpha=None) # t2284: "cuda:0 f32[1, 512, 11008]" - # t2284 = prims.add(1.0, t2283) # t2284: "cuda:0 f32[1, 512, 11008]" - t2285 = prims.reciprocal(t2284) # t2285: "cuda:0 f32[1, 512, 11008]" - t2286 = prims.convert_element_type(t2285, dtypes.bfloat16) # t2286: "cuda:0 bf16[1, 512, 11008]" - t2287 = prims.convert_element_type(t2279, dtypes.float32) # t2287: "cuda:0 f32[1, 512, 11008]" - t2288 = prims.convert_element_type(t2286, dtypes.float32) # t2288: "cuda:0 f32[1, 512, 11008]" - t2289 = ltorch.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - # t2289 = prims.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - t2290 = prims.convert_element_type(t2289, dtypes.bfloat16) # t2290: "cuda:0 bf16[1, 512, 11008]" - t2291 = prims.convert_element_type(t2290, dtypes.float32) # t2291: "cuda:0 f32[1, 512, 11008]" - t2292 = prims.convert_element_type(t2280, dtypes.float32) # t2292: "cuda:0 f32[1, 512, 11008]" - t2293 = ltorch.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - # t2293 = prims.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - t2294 = prims.convert_element_type(t2293, dtypes.bfloat16) # t2294: "cuda:0 bf16[1, 512, 11008]" - t2295 = prims.linear(t2294, t_transformer_h_15_mlp_proj_weight, None) # t2295: "cuda:0 bf16[1, 512, 4096]" - t2296 = prims.convert_element_type(t2295, dtypes.float32) # t2296: "cuda:0 f32[1, 512, 4096]" - t2297 = prims.convert_element_type(t2260, dtypes.float32) # t2297: "cuda:0 f32[1, 512, 4096]" - t2298 = ltorch.add(t2296, t2297, alpha=None) # t2298: "cuda:0 f32[1, 512, 4096]" - # t2298 = prims.add(t2296, t2297) # t2298: "cuda:0 f32[1, 512, 4096]" - t2299 = prims.convert_element_type(t2298, dtypes.bfloat16) # t2299: "cuda:0 bf16[1, 512, 4096]" - t2300 = prims.convert_element_type(t2299, dtypes.float32) # t2300: "cuda:0 f32[1, 512, 4096]" - t2301 = ltorch.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - # t2301 = prims.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - t2303 = prims.sum(t2301, (2,)) # t2303: "cuda:0 f32[1, 512]" - t2304 = prims.broadcast_in_dim(t2303, [1, 512, 1], [0, 1]) # t2304: "cuda:0 f32[1, 512, 1]" - t2306 = ltorch.true_divide(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - # t2306 = prims.div(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - t2308 = ltorch.add(t2306, 1e-05, alpha=None) # t2308: "cuda:0 f32[1, 512, 1]" - # t2308 = prims.add(t2306, 1e-05) # t2308: "cuda:0 f32[1, 512, 1]" - t2309 = prims.rsqrt(t2308) # t2309: "cuda:0 f32[1, 512, 1]" - t2310 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t2310: "cuda:0 f32[1, 512, 4096]" - t2311 = ltorch.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - # t2311 = prims.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - t2312 = prims.convert_element_type(t2311, dtypes.bfloat16) # t2312: "cuda:0 bf16[1, 512, 4096]" - t2313 = prims.broadcast_in_dim(t_transformer_ln_f_weight, (1, 512, 4096), (2,)) # t2313: "cuda:0 bf16[1, 512, 4096]" - t2314 = prims.convert_element_type(t2312, dtypes.float32) # t2314: "cuda:0 f32[1, 512, 4096]" - t2315 = prims.convert_element_type(t2313, dtypes.float32) # t2315: "cuda:0 f32[1, 512, 4096]" - t2316 = ltorch.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - # t2316 = prims.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - t2317 = prims.convert_element_type(t2316, dtypes.bfloat16) # t2317: "cuda:0 bf16[1, 512, 4096]" - t2318 = prims.linear(t2317, t_lm_head_weight, None) # t2318: "cuda:0 bf16[1, 512, 32000]" - return {'output': t2318, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t2318,)}, ((idx, t5, t11, t12, t17, t16, t19, t_transformer_h_0_attn_attn_weight, t46, t47, t49, t50, t62, t63, t65, t66, t71, t74, t38, t75, t76, t77, t78, t80, t_transformer_h_0_attn_proj_weight, t86, t95, t96, t101, t100, t103, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t108, t110, t113, t112, t117, t116, t119, t_transformer_h_0_mlp_proj_weight, t125, t134, t135, t140, t139, t142, t_transformer_h_1_attn_attn_weight, t185, t186, t188, t189, t201, t202, t204, t205, t211, t214, t176, t215, t216, t217, t218, t225, t_transformer_h_1_attn_proj_weight, t231, t240, t241, t246, t245, t248, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t253, t255, t258, t257, t262, t261, t264, t_transformer_h_1_mlp_proj_weight, t270, t279, t280, t285, t284, t287, t_transformer_h_2_attn_attn_weight, t330, t331, t333, t334, t346, t347, t349, t350, t356, t359, t321, t360, t361, t362, t363, t370, t_transformer_h_2_attn_proj_weight, t376, t385, t386, t391, t390, t393, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t398, t400, t403, t402, t407, t406, t409, t_transformer_h_2_mlp_proj_weight, t415, t424, t425, t430, t429, t432, t_transformer_h_3_attn_attn_weight, t475, t476, t478, t479, t491, t492, t494, t495, t501, t504, t466, t505, t506, t507, t508, t515, t_transformer_h_3_attn_proj_weight, t521, t530, t531, t536, t535, t538, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t543, t545, t548, t547, t552, t551, t554, t_transformer_h_3_mlp_proj_weight, t560, t569, t570, t575, t574, t577, t_transformer_h_4_attn_attn_weight, t620, t621, t623, t624, t636, t637, t639, t640, t646, t649, t611, t650, t651, t652, t653, t660, t_transformer_h_4_attn_proj_weight, t666, t675, t676, t681, t680, t683, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t688, t690, t693, t692, t697, t696, t699, t_transformer_h_4_mlp_proj_weight, t705, t714, t715, t720, t719, t722, t_transformer_h_5_attn_attn_weight, t765, t766, t768, t769, t781, t782, t784, t785, t791, t794, t756, t795, t796, t797, t798, t805, t_transformer_h_5_attn_proj_weight, t811, t820, t821, t826, t825, t828, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t833, t835, t838, t837, t842, t841, t844, t_transformer_h_5_mlp_proj_weight, t850, t859, t860, t865, t864, t867, t_transformer_h_6_attn_attn_weight, t910, t911, t913, t914, t926, t927, t929, t930, t936, t939, t901, t940, t941, t942, t943, t950, t_transformer_h_6_attn_proj_weight, t956, t965, t966, t971, t970, t973, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t978, t980, t983, t982, t987, t986, t989, t_transformer_h_6_mlp_proj_weight, t995, t1004, t1005, t1010, t1009, t1012, t_transformer_h_7_attn_attn_weight, t1055, t1056, t1058, t1059, t1071, t1072, t1074, t1075, t1081, t1084, t1046, t1085, t1086, t1087, t1088, t1095, t_transformer_h_7_attn_proj_weight, t1101, t1110, t1111, t1116, t1115, t1118, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t1123, t1125, t1128, t1127, t1132, t1131, t1134, t_transformer_h_7_mlp_proj_weight, t1140, t1149, t1150, t1155, t1154, t1157, t_transformer_h_8_attn_attn_weight, t1200, t1201, t1203, t1204, t1216, t1217, t1219, t1220, t1226, t1229, t1191, t1230, t1231, t1232, t1233, t1240, t_transformer_h_8_attn_proj_weight, t1246, t1255, t1256, t1261, t1260, t1263, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t1268, t1270, t1273, t1272, t1277, t1276, t1279, t_transformer_h_8_mlp_proj_weight, t1285, t1294, t1295, t1300, t1299, t1302, t_transformer_h_9_attn_attn_weight, t1345, t1346, t1348, t1349, t1361, t1362, t1364, t1365, t1371, t1374, t1336, t1375, t1376, t1377, t1378, t1385, t_transformer_h_9_attn_proj_weight, t1391, t1400, t1401, t1406, t1405, t1408, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t1413, t1415, t1418, t1417, t1422, t1421, t1424, t_transformer_h_9_mlp_proj_weight, t1430, t1439, t1440, t1445, t1444, t1447, t_transformer_h_10_attn_attn_weight, t1490, t1491, t1493, t1494, t1506, t1507, t1509, t1510, t1516, t1519, t1481, t1520, t1521, t1522, t1523, t1530, t_transformer_h_10_attn_proj_weight, t1536, t1545, t1546, t1551, t1550, t1553, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t1558, t1560, t1563, t1562, t1567, t1566, t1569, t_transformer_h_10_mlp_proj_weight, t1575, t1584, t1585, t1590, t1589, t1592, t_transformer_h_11_attn_attn_weight, t1635, t1636, t1638, t1639, t1651, t1652, t1654, t1655, t1661, t1664, t1626, t1665, t1666, t1667, t1668, t1675, t_transformer_h_11_attn_proj_weight, t1681, t1690, t1691, t1696, t1695, t1698, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t1703, t1705, t1708, t1707, t1712, t1711, t1714, t_transformer_h_11_mlp_proj_weight, t1720, t1729, t1730, t1735, t1734, t1737, t_transformer_h_12_attn_attn_weight, t1780, t1781, t1783, t1784, t1796, t1797, t1799, t1800, t1806, t1809, t1771, t1810, t1811, t1812, t1813, t1820, t_transformer_h_12_attn_proj_weight, t1826, t1835, t1836, t1841, t1840, t1843, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t1848, t1850, t1853, t1852, t1857, t1856, t1859, t_transformer_h_12_mlp_proj_weight, t1865, t1874, t1875, t1880, t1879, t1882, t_transformer_h_13_attn_attn_weight, t1925, t1926, t1928, t1929, t1941, t1942, t1944, t1945, t1951, t1954, t1916, t1955, t1956, t1957, t1958, t1965, t_transformer_h_13_attn_proj_weight, t1971, t1980, t1981, t1986, t1985, t1988, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t1993, t1995, t1998, t1997, t2002, t2001, t2004, t_transformer_h_13_mlp_proj_weight, t2010, t2019, t2020, t2025, t2024, t2027, t_transformer_h_14_attn_attn_weight, t2070, t2071, t2073, t2074, t2086, t2087, t2089, t2090, t2096, t2099, t2061, t2100, t2101, t2102, t2103, t2110, t_transformer_h_14_attn_proj_weight, t2116, t2125, t2126, t2131, t2130, t2133, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t2138, t2140, t2143, t2142, t2147, t2146, t2149, t_transformer_h_14_mlp_proj_weight, t2155, t2164, t2165, t2170, t2169, t2172, t_transformer_h_15_attn_attn_weight, t2215, t2216, t2218, t2219, t2231, t2232, t2234, t2235, t2241, t2244, t2206, t2245, t2246, t2247, t2248, t2255, t_transformer_h_15_attn_proj_weight, t2261, t2270, t2271, t2276, t2275, t2278, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t2283, t2285, t2288, t2287, t2292, t2291, t2294, t_transformer_h_15_mlp_proj_weight, t2300, t2309, t2310, t2315, t2314, t2317, t_lm_head_weight), (32000, False, False, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0)) -============================================ END: after forward_trc transform_for_execution -============================================ START: LABEL forward_trc -============================================ START: before _transform_for_operator_executor_execution -# Constructed by Dead Code Elimination (took 9 milliseconds) -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight): - # idx: "cuda:0 i64[1, 512]" - # tos1: "cuda:0 f32[4096, 128]" - # t_lm_head_weight: "cuda:0 bf16[32000, 4096]" - # t_sin: "cuda:0 f32[4096, 128]" - # t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_0_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_0_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_1_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_1_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_1_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_2_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_2_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_2_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_3_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_3_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_3_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_4_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_4_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_4_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_5_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_5_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_5_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_6_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_6_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_6_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_7_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_7_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_7_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_8_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_8_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_8_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_9_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_9_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_9_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_10_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_10_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_10_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_11_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_11_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_11_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_12_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_12_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_12_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_13_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_13_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_13_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_14_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_14_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_14_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_15_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_15_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_15_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_ln_f_weight: "cuda:0 bf16[4096]" - # t_transformer_wte_weight: "cuda:0 bf16[32000, 4096]" - t0 = prims.slice_prim(tos1, [0, 0], [512, 128], [1, 1]) # t0: "cuda:0 f32[512, 128]" - t1 = prims.slice_prim(t_sin, [0, 0], [512, 128], [1, 1]) # t1: "cuda:0 f32[512, 128]" - t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 512, 4096]" - # t2 = ltorch.reshape(idx, [512]) # t2: "cuda:0 i64[512]" - # t2 = prims.reshape(idx, (512,)) # t2: "cuda:0 i64[512]" - # t3 = prims.take(t_transformer_wte_weight, t2, 0) # t3: "cuda:0 bf16[512, 4096]" - # t4 = ltorch.reshape(t3, [1, 512, 4096]) # t4: "cuda:0 bf16[1, 512, 4096]" - # t4 = prims.reshape(t3, (1, 512, 4096)) # t4: "cuda:0 bf16[1, 512, 4096]" - t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 512, 4096]" - t6 = ltorch.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - # t6 = prims.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - t7 = prims.sum(t6, (2,)) # t7: "cuda:0 f32[1, 512]" - t8 = prims.broadcast_in_dim(t7, [1, 512, 1], [0, 1]) # t8: "cuda:0 f32[1, 512, 1]" - t9 = ltorch.true_divide(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - # t9 = prims.div(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - t10 = ltorch.add(t9, 1e-05, alpha=None) # t10: "cuda:0 f32[1, 512, 1]" - # t10 = prims.add(t9, 1e-05) # t10: "cuda:0 f32[1, 512, 1]" - t11 = prims.rsqrt(t10) # t11: "cuda:0 f32[1, 512, 1]" - t12 = prims.broadcast_in_dim(t11, (1, 512, 4096), (0, 1, 2)) # t12: "cuda:0 f32[1, 512, 4096]" - t13 = ltorch.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - # t13 = prims.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - t14 = prims.convert_element_type(t13, dtypes.bfloat16) # t14: "cuda:0 bf16[1, 512, 4096]" - t15 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, (1, 512, 4096), (2,)) # t15: "cuda:0 bf16[1, 512, 4096]" - t16 = prims.convert_element_type(t14, dtypes.float32) # t16: "cuda:0 f32[1, 512, 4096]" - t17 = prims.convert_element_type(t15, dtypes.float32) # t17: "cuda:0 f32[1, 512, 4096]" - t18 = ltorch.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - # t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - t19 = prims.convert_element_type(t18, dtypes.bfloat16) # t19: "cuda:0 bf16[1, 512, 4096]" - t20 = prims.linear(t19, t_transformer_h_0_attn_attn_weight, None) # t20: "cuda:0 bf16[1, 512, 12288]" - t21 = prims.reshape(t20, (1, 512, 32, 3, 128)) # t21: "cuda:0 bf16[1, 512, 32, 3, 128]" - t22 = prims.transpose(t21, (0, 2, 3, 1, 4)) # t22: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t23, t24, t25) = ltorch.split(t22, (1, 1, 1), 2) - # t23 = prims.slice_prim(t22, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t23: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t24 = prims.slice_prim(t22, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t24: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t25 = prims.slice_prim(t22, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t25: "cuda:0 bf16[1, 32, 1, 512, 128]" - t26 = prims.reshape(t23, (1, 32, 512, 128)) # t26: "cuda:0 bf16[1, 32, 512, 128]" - t32 = prims.reshape(t24, (1, 32, 512, 128)) # t32: "cuda:0 bf16[1, 32, 512, 128]" - t38 = prims.reshape(t25, (1, 32, 512, 128)) # t38: "cuda:0 bf16[1, 32, 512, 128]" - t39 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t39: "cuda:0 bf16[1, 32, 512, 128]" - t40 = prims.slice_prim(t39, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t40: "cuda:0 bf16[1, 32, 512, 64]" - t41 = prims.slice_prim(t39, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t41: "cuda:0 bf16[1, 32, 512, 64]" - t42 = prims.convert_element_type(t41, dtypes.float32) # t42: "cuda:0 f32[1, 32, 512, 64]" - t43 = prims.neg(t42) # t43: "cuda:0 f32[1, 32, 512, 64]" - t44 = prims.convert_element_type(t43, dtypes.bfloat16) # t44: "cuda:0 bf16[1, 32, 512, 64]" - t45 = prims.cat((t44, t40), -1) # t45: "cuda:0 bf16[1, 32, 512, 128]" - t46 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t46: "cuda:0 f32[1, 32, 512, 128]" - t47 = prims.convert_element_type(t39, dtypes.float32) # t47: "cuda:0 f32[1, 32, 512, 128]" - t48 = ltorch.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - # t48 = prims.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - t49 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t49: "cuda:0 f32[1, 32, 512, 128]" - t50 = prims.convert_element_type(t45, dtypes.float32) # t50: "cuda:0 f32[1, 32, 512, 128]" - t51 = ltorch.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - # t51 = prims.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - t52 = ltorch.add(t48, t51, alpha=None) # t52: "cuda:0 f32[1, 32, 512, 128]" - # t52 = prims.add(t48, t51) # t52: "cuda:0 f32[1, 32, 512, 128]" - t53 = prims.convert_element_type(t52, dtypes.bfloat16) # t53: "cuda:0 bf16[1, 32, 512, 128]" - t54 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 32, 512, 128]" - t55 = prims.slice_prim(t54, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t55: "cuda:0 bf16[1, 32, 512, 64]" - t56 = prims.slice_prim(t54, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t56: "cuda:0 bf16[1, 32, 512, 64]" - t57 = prims.convert_element_type(t56, dtypes.float32) # t57: "cuda:0 f32[1, 32, 512, 64]" - t58 = prims.neg(t57) # t58: "cuda:0 f32[1, 32, 512, 64]" - t59 = prims.convert_element_type(t58, dtypes.bfloat16) # t59: "cuda:0 bf16[1, 32, 512, 64]" - t61 = prims.cat((t59, t55), -1) # t61: "cuda:0 bf16[1, 32, 512, 128]" - t62 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t62: "cuda:0 f32[1, 32, 512, 128]" - t63 = prims.convert_element_type(t54, dtypes.float32) # t63: "cuda:0 f32[1, 32, 512, 128]" - t64 = ltorch.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - # t64 = prims.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - t65 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t65: "cuda:0 f32[1, 32, 512, 128]" - t66 = prims.convert_element_type(t61, dtypes.float32) # t66: "cuda:0 f32[1, 32, 512, 128]" - t67 = ltorch.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - # t67 = prims.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - t68 = ltorch.add(t64, t67, alpha=None) # t68: "cuda:0 f32[1, 32, 512, 128]" - # t68 = prims.add(t64, t67) # t68: "cuda:0 f32[1, 32, 512, 128]" - t69 = prims.convert_element_type(t68, dtypes.bfloat16) # t69: "cuda:0 bf16[1, 32, 512, 128]" - t70 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t70: "cuda:0 bf16[1, 32, 512, 0]" - t71 = prims.cat((t53, t70), -1) # t71: "cuda:0 bf16[1, 32, 512, 128]" - t72 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t72: "cuda:0 bf16[1, 32, 512, 0]" - t74 = prims.cat((t69, t72), -1) # t74: "cuda:0 bf16[1, 32, 512, 128]" - (t75, t76, t77, t78) = cudnn_sdpa_fwd(t71, t74, t38, None, 0.0, True, scale=0.08838834764831843) - t79 = prims.transpose(t75, (0, 2, 1, 3)) # t79: "cuda:0 bf16[1, 512, 32, 128]" - t80 = prims.reshape(t79, (1, 512, 4096)) # t80: "cuda:0 bf16[1, 512, 4096]" - t81 = prims.linear(t80, t_transformer_h_0_attn_proj_weight, None) # t81: "cuda:0 bf16[1, 512, 4096]" - t82 = prims.convert_element_type(t81, dtypes.float32) # t82: "cuda:0 f32[1, 512, 4096]" - t83 = prims.convert_element_type(t4, dtypes.float32) # t83: "cuda:0 f32[1, 512, 4096]" - t84 = ltorch.add(t82, t83, alpha=None) # t84: "cuda:0 f32[1, 512, 4096]" - # t84 = prims.add(t82, t83) # t84: "cuda:0 f32[1, 512, 4096]" - t85 = prims.convert_element_type(t84, dtypes.bfloat16) # t85: "cuda:0 bf16[1, 512, 4096]" - t86 = prims.convert_element_type(t85, dtypes.float32) # t86: "cuda:0 f32[1, 512, 4096]" - t87 = ltorch.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - # t87 = prims.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - t89 = prims.sum(t87, (2,)) # t89: "cuda:0 f32[1, 512]" - t90 = prims.broadcast_in_dim(t89, [1, 512, 1], [0, 1]) # t90: "cuda:0 f32[1, 512, 1]" - t92 = ltorch.true_divide(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - # t92 = prims.div(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - t94 = ltorch.add(t92, 1e-05, alpha=None) # t94: "cuda:0 f32[1, 512, 1]" - # t94 = prims.add(t92, 1e-05) # t94: "cuda:0 f32[1, 512, 1]" - t95 = prims.rsqrt(t94) # t95: "cuda:0 f32[1, 512, 1]" - t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: "cuda:0 f32[1, 512, 4096]" - t97 = ltorch.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - # t97 = prims.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - t98 = prims.convert_element_type(t97, dtypes.bfloat16) # t98: "cuda:0 bf16[1, 512, 4096]" - t99 = prims.broadcast_in_dim(t_transformer_h_0_norm_2_weight, (1, 512, 4096), (2,)) # t99: "cuda:0 bf16[1, 512, 4096]" - t100 = prims.convert_element_type(t98, dtypes.float32) # t100: "cuda:0 f32[1, 512, 4096]" - t101 = prims.convert_element_type(t99, dtypes.float32) # t101: "cuda:0 f32[1, 512, 4096]" - t102 = ltorch.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - # t102 = prims.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - t103 = prims.convert_element_type(t102, dtypes.bfloat16) # t103: "cuda:0 bf16[1, 512, 4096]" - t104 = prims.linear(t103, t_transformer_h_0_mlp_fc_1_weight, None) # t104: "cuda:0 bf16[1, 512, 11008]" - t105 = prims.linear(t103, t_transformer_h_0_mlp_fc_2_weight, None) # t105: "cuda:0 bf16[1, 512, 11008]" - t106 = prims.convert_element_type(t104, dtypes.float32) # t106: "cuda:0 f32[1, 512, 11008]" - t107 = prims.neg(t106) # t107: "cuda:0 f32[1, 512, 11008]" - t108 = prims.exp(t107) # t108: "cuda:0 f32[1, 512, 11008]" - t109 = ltorch.add(1.0, t108, alpha=None) # t109: "cuda:0 f32[1, 512, 11008]" - # t109 = prims.add(1.0, t108) # t109: "cuda:0 f32[1, 512, 11008]" - t110 = prims.reciprocal(t109) # t110: "cuda:0 f32[1, 512, 11008]" - t111 = prims.convert_element_type(t110, dtypes.bfloat16) # t111: "cuda:0 bf16[1, 512, 11008]" - t112 = prims.convert_element_type(t104, dtypes.float32) # t112: "cuda:0 f32[1, 512, 11008]" - t113 = prims.convert_element_type(t111, dtypes.float32) # t113: "cuda:0 f32[1, 512, 11008]" - t114 = ltorch.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - # t114 = prims.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - t115 = prims.convert_element_type(t114, dtypes.bfloat16) # t115: "cuda:0 bf16[1, 512, 11008]" - t116 = prims.convert_element_type(t115, dtypes.float32) # t116: "cuda:0 f32[1, 512, 11008]" - t117 = prims.convert_element_type(t105, dtypes.float32) # t117: "cuda:0 f32[1, 512, 11008]" - t118 = ltorch.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - # t118 = prims.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - t119 = prims.convert_element_type(t118, dtypes.bfloat16) # t119: "cuda:0 bf16[1, 512, 11008]" - t120 = prims.linear(t119, t_transformer_h_0_mlp_proj_weight, None) # t120: "cuda:0 bf16[1, 512, 4096]" - t121 = prims.convert_element_type(t120, dtypes.float32) # t121: "cuda:0 f32[1, 512, 4096]" - t122 = prims.convert_element_type(t85, dtypes.float32) # t122: "cuda:0 f32[1, 512, 4096]" - t123 = ltorch.add(t121, t122, alpha=None) # t123: "cuda:0 f32[1, 512, 4096]" - # t123 = prims.add(t121, t122) # t123: "cuda:0 f32[1, 512, 4096]" - t124 = prims.convert_element_type(t123, dtypes.bfloat16) # t124: "cuda:0 bf16[1, 512, 4096]" - t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 512, 4096]" - t126 = ltorch.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - # t126 = prims.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - t128 = prims.sum(t126, (2,)) # t128: "cuda:0 f32[1, 512]" - t129 = prims.broadcast_in_dim(t128, [1, 512, 1], [0, 1]) # t129: "cuda:0 f32[1, 512, 1]" - t131 = ltorch.true_divide(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - # t131 = prims.div(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - t133 = ltorch.add(t131, 1e-05, alpha=None) # t133: "cuda:0 f32[1, 512, 1]" - # t133 = prims.add(t131, 1e-05) # t133: "cuda:0 f32[1, 512, 1]" - t134 = prims.rsqrt(t133) # t134: "cuda:0 f32[1, 512, 1]" - t135 = prims.broadcast_in_dim(t134, (1, 512, 4096), (0, 1, 2)) # t135: "cuda:0 f32[1, 512, 4096]" - t136 = ltorch.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - # t136 = prims.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: "cuda:0 bf16[1, 512, 4096]" - t138 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_weight, (1, 512, 4096), (2,)) # t138: "cuda:0 bf16[1, 512, 4096]" - t139 = prims.convert_element_type(t137, dtypes.float32) # t139: "cuda:0 f32[1, 512, 4096]" - t140 = prims.convert_element_type(t138, dtypes.float32) # t140: "cuda:0 f32[1, 512, 4096]" - t141 = ltorch.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - # t141 = prims.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - t142 = prims.convert_element_type(t141, dtypes.bfloat16) # t142: "cuda:0 bf16[1, 512, 4096]" - t143 = prims.linear(t142, t_transformer_h_1_attn_attn_weight, None) # t143: "cuda:0 bf16[1, 512, 12288]" - t149 = prims.reshape(t143, (1, 512, 32, 3, 128)) # t149: "cuda:0 bf16[1, 512, 32, 3, 128]" - t155 = prims.transpose(t149, (0, 2, 3, 1, 4)) # t155: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t156, t157, t158) = ltorch.split(t155, (1, 1, 1), 2) - # t156 = prims.slice_prim(t155, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t156: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t157 = prims.slice_prim(t155, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t157: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t158 = prims.slice_prim(t155, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t158: "cuda:0 bf16[1, 32, 1, 512, 128]" - t164 = prims.reshape(t156, (1, 32, 512, 128)) # t164: "cuda:0 bf16[1, 32, 512, 128]" - t170 = prims.reshape(t157, (1, 32, 512, 128)) # t170: "cuda:0 bf16[1, 32, 512, 128]" - t176 = prims.reshape(t158, (1, 32, 512, 128)) # t176: "cuda:0 bf16[1, 32, 512, 128]" - t177 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t177: "cuda:0 bf16[1, 32, 512, 128]" - t178 = prims.slice_prim(t177, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t178: "cuda:0 bf16[1, 32, 512, 64]" - t179 = prims.slice_prim(t177, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t179: "cuda:0 bf16[1, 32, 512, 64]" - t180 = prims.convert_element_type(t179, dtypes.float32) # t180: "cuda:0 f32[1, 32, 512, 64]" - t181 = prims.neg(t180) # t181: "cuda:0 f32[1, 32, 512, 64]" - t182 = prims.convert_element_type(t181, dtypes.bfloat16) # t182: "cuda:0 bf16[1, 32, 512, 64]" - t184 = prims.cat((t182, t178), -1) # t184: "cuda:0 bf16[1, 32, 512, 128]" - t185 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t185: "cuda:0 f32[1, 32, 512, 128]" - t186 = prims.convert_element_type(t177, dtypes.float32) # t186: "cuda:0 f32[1, 32, 512, 128]" - t187 = ltorch.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - # t187 = prims.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - t188 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t188: "cuda:0 f32[1, 32, 512, 128]" - t189 = prims.convert_element_type(t184, dtypes.float32) # t189: "cuda:0 f32[1, 32, 512, 128]" - t190 = ltorch.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - # t190 = prims.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - t191 = ltorch.add(t187, t190, alpha=None) # t191: "cuda:0 f32[1, 32, 512, 128]" - # t191 = prims.add(t187, t190) # t191: "cuda:0 f32[1, 32, 512, 128]" - t192 = prims.convert_element_type(t191, dtypes.bfloat16) # t192: "cuda:0 bf16[1, 32, 512, 128]" - t193 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t193: "cuda:0 bf16[1, 32, 512, 128]" - t194 = prims.slice_prim(t193, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t194: "cuda:0 bf16[1, 32, 512, 64]" - t195 = prims.slice_prim(t193, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t195: "cuda:0 bf16[1, 32, 512, 64]" - t196 = prims.convert_element_type(t195, dtypes.float32) # t196: "cuda:0 f32[1, 32, 512, 64]" - t197 = prims.neg(t196) # t197: "cuda:0 f32[1, 32, 512, 64]" - t198 = prims.convert_element_type(t197, dtypes.bfloat16) # t198: "cuda:0 bf16[1, 32, 512, 64]" - t200 = prims.cat((t198, t194), -1) # t200: "cuda:0 bf16[1, 32, 512, 128]" - t201 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t201: "cuda:0 f32[1, 32, 512, 128]" - t202 = prims.convert_element_type(t193, dtypes.float32) # t202: "cuda:0 f32[1, 32, 512, 128]" - t203 = ltorch.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - # t203 = prims.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - t204 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t204: "cuda:0 f32[1, 32, 512, 128]" - t205 = prims.convert_element_type(t200, dtypes.float32) # t205: "cuda:0 f32[1, 32, 512, 128]" - t206 = ltorch.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - # t206 = prims.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - t207 = ltorch.add(t203, t206, alpha=None) # t207: "cuda:0 f32[1, 32, 512, 128]" - # t207 = prims.add(t203, t206) # t207: "cuda:0 f32[1, 32, 512, 128]" - t208 = prims.convert_element_type(t207, dtypes.bfloat16) # t208: "cuda:0 bf16[1, 32, 512, 128]" - t209 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t209: "cuda:0 bf16[1, 32, 512, 0]" - t211 = prims.cat((t192, t209), -1) # t211: "cuda:0 bf16[1, 32, 512, 128]" - t212 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t212: "cuda:0 bf16[1, 32, 512, 0]" - t214 = prims.cat((t208, t212), -1) # t214: "cuda:0 bf16[1, 32, 512, 128]" - (t215, t216, t217, t218) = cudnn_sdpa_fwd(t211, t214, t176, None, 0.0, True, scale=0.08838834764831843) - t221 = prims.transpose(t215, (0, 2, 1, 3)) # t221: "cuda:0 bf16[1, 512, 32, 128]" - t225 = prims.reshape(t221, (1, 512, 4096)) # t225: "cuda:0 bf16[1, 512, 4096]" - t226 = prims.linear(t225, t_transformer_h_1_attn_proj_weight, None) # t226: "cuda:0 bf16[1, 512, 4096]" - t227 = prims.convert_element_type(t226, dtypes.float32) # t227: "cuda:0 f32[1, 512, 4096]" - t228 = prims.convert_element_type(t124, dtypes.float32) # t228: "cuda:0 f32[1, 512, 4096]" - t229 = ltorch.add(t227, t228, alpha=None) # t229: "cuda:0 f32[1, 512, 4096]" - # t229 = prims.add(t227, t228) # t229: "cuda:0 f32[1, 512, 4096]" - t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: "cuda:0 bf16[1, 512, 4096]" - t231 = prims.convert_element_type(t230, dtypes.float32) # t231: "cuda:0 f32[1, 512, 4096]" - t232 = ltorch.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - # t232 = prims.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - t234 = prims.sum(t232, (2,)) # t234: "cuda:0 f32[1, 512]" - t235 = prims.broadcast_in_dim(t234, [1, 512, 1], [0, 1]) # t235: "cuda:0 f32[1, 512, 1]" - t237 = ltorch.true_divide(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - # t237 = prims.div(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - t239 = ltorch.add(t237, 1e-05, alpha=None) # t239: "cuda:0 f32[1, 512, 1]" - # t239 = prims.add(t237, 1e-05) # t239: "cuda:0 f32[1, 512, 1]" - t240 = prims.rsqrt(t239) # t240: "cuda:0 f32[1, 512, 1]" - t241 = prims.broadcast_in_dim(t240, (1, 512, 4096), (0, 1, 2)) # t241: "cuda:0 f32[1, 512, 4096]" - t242 = ltorch.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - # t242 = prims.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - t243 = prims.convert_element_type(t242, dtypes.bfloat16) # t243: "cuda:0 bf16[1, 512, 4096]" - t244 = prims.broadcast_in_dim(t_transformer_h_1_norm_2_weight, (1, 512, 4096), (2,)) # t244: "cuda:0 bf16[1, 512, 4096]" - t245 = prims.convert_element_type(t243, dtypes.float32) # t245: "cuda:0 f32[1, 512, 4096]" - t246 = prims.convert_element_type(t244, dtypes.float32) # t246: "cuda:0 f32[1, 512, 4096]" - t247 = ltorch.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - # t247 = prims.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - t248 = prims.convert_element_type(t247, dtypes.bfloat16) # t248: "cuda:0 bf16[1, 512, 4096]" - t249 = prims.linear(t248, t_transformer_h_1_mlp_fc_1_weight, None) # t249: "cuda:0 bf16[1, 512, 11008]" - t250 = prims.linear(t248, t_transformer_h_1_mlp_fc_2_weight, None) # t250: "cuda:0 bf16[1, 512, 11008]" - t251 = prims.convert_element_type(t249, dtypes.float32) # t251: "cuda:0 f32[1, 512, 11008]" - t252 = prims.neg(t251) # t252: "cuda:0 f32[1, 512, 11008]" - t253 = prims.exp(t252) # t253: "cuda:0 f32[1, 512, 11008]" - t254 = ltorch.add(1.0, t253, alpha=None) # t254: "cuda:0 f32[1, 512, 11008]" - # t254 = prims.add(1.0, t253) # t254: "cuda:0 f32[1, 512, 11008]" - t255 = prims.reciprocal(t254) # t255: "cuda:0 f32[1, 512, 11008]" - t256 = prims.convert_element_type(t255, dtypes.bfloat16) # t256: "cuda:0 bf16[1, 512, 11008]" - t257 = prims.convert_element_type(t249, dtypes.float32) # t257: "cuda:0 f32[1, 512, 11008]" - t258 = prims.convert_element_type(t256, dtypes.float32) # t258: "cuda:0 f32[1, 512, 11008]" - t259 = ltorch.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - # t259 = prims.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 512, 11008]" - t261 = prims.convert_element_type(t260, dtypes.float32) # t261: "cuda:0 f32[1, 512, 11008]" - t262 = prims.convert_element_type(t250, dtypes.float32) # t262: "cuda:0 f32[1, 512, 11008]" - t263 = ltorch.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - # t263 = prims.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - t264 = prims.convert_element_type(t263, dtypes.bfloat16) # t264: "cuda:0 bf16[1, 512, 11008]" - t265 = prims.linear(t264, t_transformer_h_1_mlp_proj_weight, None) # t265: "cuda:0 bf16[1, 512, 4096]" - t266 = prims.convert_element_type(t265, dtypes.float32) # t266: "cuda:0 f32[1, 512, 4096]" - t267 = prims.convert_element_type(t230, dtypes.float32) # t267: "cuda:0 f32[1, 512, 4096]" - t268 = ltorch.add(t266, t267, alpha=None) # t268: "cuda:0 f32[1, 512, 4096]" - # t268 = prims.add(t266, t267) # t268: "cuda:0 f32[1, 512, 4096]" - t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: "cuda:0 bf16[1, 512, 4096]" - t270 = prims.convert_element_type(t269, dtypes.float32) # t270: "cuda:0 f32[1, 512, 4096]" - t271 = ltorch.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - # t271 = prims.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - t273 = prims.sum(t271, (2,)) # t273: "cuda:0 f32[1, 512]" - t274 = prims.broadcast_in_dim(t273, [1, 512, 1], [0, 1]) # t274: "cuda:0 f32[1, 512, 1]" - t276 = ltorch.true_divide(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - # t276 = prims.div(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - t278 = ltorch.add(t276, 1e-05, alpha=None) # t278: "cuda:0 f32[1, 512, 1]" - # t278 = prims.add(t276, 1e-05) # t278: "cuda:0 f32[1, 512, 1]" - t279 = prims.rsqrt(t278) # t279: "cuda:0 f32[1, 512, 1]" - t280 = prims.broadcast_in_dim(t279, (1, 512, 4096), (0, 1, 2)) # t280: "cuda:0 f32[1, 512, 4096]" - t281 = ltorch.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - # t281 = prims.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - t282 = prims.convert_element_type(t281, dtypes.bfloat16) # t282: "cuda:0 bf16[1, 512, 4096]" - t283 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_weight, (1, 512, 4096), (2,)) # t283: "cuda:0 bf16[1, 512, 4096]" - t284 = prims.convert_element_type(t282, dtypes.float32) # t284: "cuda:0 f32[1, 512, 4096]" - t285 = prims.convert_element_type(t283, dtypes.float32) # t285: "cuda:0 f32[1, 512, 4096]" - t286 = ltorch.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - # t286 = prims.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - t287 = prims.convert_element_type(t286, dtypes.bfloat16) # t287: "cuda:0 bf16[1, 512, 4096]" - t288 = prims.linear(t287, t_transformer_h_2_attn_attn_weight, None) # t288: "cuda:0 bf16[1, 512, 12288]" - t294 = prims.reshape(t288, (1, 512, 32, 3, 128)) # t294: "cuda:0 bf16[1, 512, 32, 3, 128]" - t300 = prims.transpose(t294, (0, 2, 3, 1, 4)) # t300: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t301, t302, t303) = ltorch.split(t300, (1, 1, 1), 2) - # t301 = prims.slice_prim(t300, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t301: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t302 = prims.slice_prim(t300, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t302: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t303 = prims.slice_prim(t300, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t303: "cuda:0 bf16[1, 32, 1, 512, 128]" - t309 = prims.reshape(t301, (1, 32, 512, 128)) # t309: "cuda:0 bf16[1, 32, 512, 128]" - t315 = prims.reshape(t302, (1, 32, 512, 128)) # t315: "cuda:0 bf16[1, 32, 512, 128]" - t321 = prims.reshape(t303, (1, 32, 512, 128)) # t321: "cuda:0 bf16[1, 32, 512, 128]" - t322 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t322: "cuda:0 bf16[1, 32, 512, 128]" - t323 = prims.slice_prim(t322, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t323: "cuda:0 bf16[1, 32, 512, 64]" - t324 = prims.slice_prim(t322, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t324: "cuda:0 bf16[1, 32, 512, 64]" - t325 = prims.convert_element_type(t324, dtypes.float32) # t325: "cuda:0 f32[1, 32, 512, 64]" - t326 = prims.neg(t325) # t326: "cuda:0 f32[1, 32, 512, 64]" - t327 = prims.convert_element_type(t326, dtypes.bfloat16) # t327: "cuda:0 bf16[1, 32, 512, 64]" - t329 = prims.cat((t327, t323), -1) # t329: "cuda:0 bf16[1, 32, 512, 128]" - t330 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t330: "cuda:0 f32[1, 32, 512, 128]" - t331 = prims.convert_element_type(t322, dtypes.float32) # t331: "cuda:0 f32[1, 32, 512, 128]" - t332 = ltorch.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - # t332 = prims.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - t333 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t333: "cuda:0 f32[1, 32, 512, 128]" - t334 = prims.convert_element_type(t329, dtypes.float32) # t334: "cuda:0 f32[1, 32, 512, 128]" - t335 = ltorch.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - # t335 = prims.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - t336 = ltorch.add(t332, t335, alpha=None) # t336: "cuda:0 f32[1, 32, 512, 128]" - # t336 = prims.add(t332, t335) # t336: "cuda:0 f32[1, 32, 512, 128]" - t337 = prims.convert_element_type(t336, dtypes.bfloat16) # t337: "cuda:0 bf16[1, 32, 512, 128]" - t338 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t338: "cuda:0 bf16[1, 32, 512, 128]" - t339 = prims.slice_prim(t338, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t339: "cuda:0 bf16[1, 32, 512, 64]" - t340 = prims.slice_prim(t338, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t340: "cuda:0 bf16[1, 32, 512, 64]" - t341 = prims.convert_element_type(t340, dtypes.float32) # t341: "cuda:0 f32[1, 32, 512, 64]" - t342 = prims.neg(t341) # t342: "cuda:0 f32[1, 32, 512, 64]" - t343 = prims.convert_element_type(t342, dtypes.bfloat16) # t343: "cuda:0 bf16[1, 32, 512, 64]" - t345 = prims.cat((t343, t339), -1) # t345: "cuda:0 bf16[1, 32, 512, 128]" - t346 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t346: "cuda:0 f32[1, 32, 512, 128]" - t347 = prims.convert_element_type(t338, dtypes.float32) # t347: "cuda:0 f32[1, 32, 512, 128]" - t348 = ltorch.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - # t348 = prims.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - t349 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t349: "cuda:0 f32[1, 32, 512, 128]" - t350 = prims.convert_element_type(t345, dtypes.float32) # t350: "cuda:0 f32[1, 32, 512, 128]" - t351 = ltorch.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - # t351 = prims.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - t352 = ltorch.add(t348, t351, alpha=None) # t352: "cuda:0 f32[1, 32, 512, 128]" - # t352 = prims.add(t348, t351) # t352: "cuda:0 f32[1, 32, 512, 128]" - t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: "cuda:0 bf16[1, 32, 512, 128]" - t354 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t354: "cuda:0 bf16[1, 32, 512, 0]" - t356 = prims.cat((t337, t354), -1) # t356: "cuda:0 bf16[1, 32, 512, 128]" - t357 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t357: "cuda:0 bf16[1, 32, 512, 0]" - t359 = prims.cat((t353, t357), -1) # t359: "cuda:0 bf16[1, 32, 512, 128]" - (t360, t361, t362, t363) = cudnn_sdpa_fwd(t356, t359, t321, None, 0.0, True, scale=0.08838834764831843) - t366 = prims.transpose(t360, (0, 2, 1, 3)) # t366: "cuda:0 bf16[1, 512, 32, 128]" - t370 = prims.reshape(t366, (1, 512, 4096)) # t370: "cuda:0 bf16[1, 512, 4096]" - t371 = prims.linear(t370, t_transformer_h_2_attn_proj_weight, None) # t371: "cuda:0 bf16[1, 512, 4096]" - t372 = prims.convert_element_type(t371, dtypes.float32) # t372: "cuda:0 f32[1, 512, 4096]" - t373 = prims.convert_element_type(t269, dtypes.float32) # t373: "cuda:0 f32[1, 512, 4096]" - t374 = ltorch.add(t372, t373, alpha=None) # t374: "cuda:0 f32[1, 512, 4096]" - # t374 = prims.add(t372, t373) # t374: "cuda:0 f32[1, 512, 4096]" - t375 = prims.convert_element_type(t374, dtypes.bfloat16) # t375: "cuda:0 bf16[1, 512, 4096]" - t376 = prims.convert_element_type(t375, dtypes.float32) # t376: "cuda:0 f32[1, 512, 4096]" - t377 = ltorch.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - # t377 = prims.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - t379 = prims.sum(t377, (2,)) # t379: "cuda:0 f32[1, 512]" - t380 = prims.broadcast_in_dim(t379, [1, 512, 1], [0, 1]) # t380: "cuda:0 f32[1, 512, 1]" - t382 = ltorch.true_divide(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - # t382 = prims.div(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - t384 = ltorch.add(t382, 1e-05, alpha=None) # t384: "cuda:0 f32[1, 512, 1]" - # t384 = prims.add(t382, 1e-05) # t384: "cuda:0 f32[1, 512, 1]" - t385 = prims.rsqrt(t384) # t385: "cuda:0 f32[1, 512, 1]" - t386 = prims.broadcast_in_dim(t385, (1, 512, 4096), (0, 1, 2)) # t386: "cuda:0 f32[1, 512, 4096]" - t387 = ltorch.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - # t387 = prims.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - t388 = prims.convert_element_type(t387, dtypes.bfloat16) # t388: "cuda:0 bf16[1, 512, 4096]" - t389 = prims.broadcast_in_dim(t_transformer_h_2_norm_2_weight, (1, 512, 4096), (2,)) # t389: "cuda:0 bf16[1, 512, 4096]" - t390 = prims.convert_element_type(t388, dtypes.float32) # t390: "cuda:0 f32[1, 512, 4096]" - t391 = prims.convert_element_type(t389, dtypes.float32) # t391: "cuda:0 f32[1, 512, 4096]" - t392 = ltorch.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - # t392 = prims.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - t393 = prims.convert_element_type(t392, dtypes.bfloat16) # t393: "cuda:0 bf16[1, 512, 4096]" - t394 = prims.linear(t393, t_transformer_h_2_mlp_fc_1_weight, None) # t394: "cuda:0 bf16[1, 512, 11008]" - t395 = prims.linear(t393, t_transformer_h_2_mlp_fc_2_weight, None) # t395: "cuda:0 bf16[1, 512, 11008]" - t396 = prims.convert_element_type(t394, dtypes.float32) # t396: "cuda:0 f32[1, 512, 11008]" - t397 = prims.neg(t396) # t397: "cuda:0 f32[1, 512, 11008]" - t398 = prims.exp(t397) # t398: "cuda:0 f32[1, 512, 11008]" - t399 = ltorch.add(1.0, t398, alpha=None) # t399: "cuda:0 f32[1, 512, 11008]" - # t399 = prims.add(1.0, t398) # t399: "cuda:0 f32[1, 512, 11008]" - t400 = prims.reciprocal(t399) # t400: "cuda:0 f32[1, 512, 11008]" - t401 = prims.convert_element_type(t400, dtypes.bfloat16) # t401: "cuda:0 bf16[1, 512, 11008]" - t402 = prims.convert_element_type(t394, dtypes.float32) # t402: "cuda:0 f32[1, 512, 11008]" - t403 = prims.convert_element_type(t401, dtypes.float32) # t403: "cuda:0 f32[1, 512, 11008]" - t404 = ltorch.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - # t404 = prims.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - t405 = prims.convert_element_type(t404, dtypes.bfloat16) # t405: "cuda:0 bf16[1, 512, 11008]" - t406 = prims.convert_element_type(t405, dtypes.float32) # t406: "cuda:0 f32[1, 512, 11008]" - t407 = prims.convert_element_type(t395, dtypes.float32) # t407: "cuda:0 f32[1, 512, 11008]" - t408 = ltorch.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - # t408 = prims.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - t409 = prims.convert_element_type(t408, dtypes.bfloat16) # t409: "cuda:0 bf16[1, 512, 11008]" - t410 = prims.linear(t409, t_transformer_h_2_mlp_proj_weight, None) # t410: "cuda:0 bf16[1, 512, 4096]" - t411 = prims.convert_element_type(t410, dtypes.float32) # t411: "cuda:0 f32[1, 512, 4096]" - t412 = prims.convert_element_type(t375, dtypes.float32) # t412: "cuda:0 f32[1, 512, 4096]" - t413 = ltorch.add(t411, t412, alpha=None) # t413: "cuda:0 f32[1, 512, 4096]" - # t413 = prims.add(t411, t412) # t413: "cuda:0 f32[1, 512, 4096]" - t414 = prims.convert_element_type(t413, dtypes.bfloat16) # t414: "cuda:0 bf16[1, 512, 4096]" - t415 = prims.convert_element_type(t414, dtypes.float32) # t415: "cuda:0 f32[1, 512, 4096]" - t416 = ltorch.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - # t416 = prims.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - t418 = prims.sum(t416, (2,)) # t418: "cuda:0 f32[1, 512]" - t419 = prims.broadcast_in_dim(t418, [1, 512, 1], [0, 1]) # t419: "cuda:0 f32[1, 512, 1]" - t421 = ltorch.true_divide(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - # t421 = prims.div(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - t423 = ltorch.add(t421, 1e-05, alpha=None) # t423: "cuda:0 f32[1, 512, 1]" - # t423 = prims.add(t421, 1e-05) # t423: "cuda:0 f32[1, 512, 1]" - t424 = prims.rsqrt(t423) # t424: "cuda:0 f32[1, 512, 1]" - t425 = prims.broadcast_in_dim(t424, (1, 512, 4096), (0, 1, 2)) # t425: "cuda:0 f32[1, 512, 4096]" - t426 = ltorch.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - # t426 = prims.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - t427 = prims.convert_element_type(t426, dtypes.bfloat16) # t427: "cuda:0 bf16[1, 512, 4096]" - t428 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_weight, (1, 512, 4096), (2,)) # t428: "cuda:0 bf16[1, 512, 4096]" - t429 = prims.convert_element_type(t427, dtypes.float32) # t429: "cuda:0 f32[1, 512, 4096]" - t430 = prims.convert_element_type(t428, dtypes.float32) # t430: "cuda:0 f32[1, 512, 4096]" - t431 = ltorch.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - # t431 = prims.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - t432 = prims.convert_element_type(t431, dtypes.bfloat16) # t432: "cuda:0 bf16[1, 512, 4096]" - t433 = prims.linear(t432, t_transformer_h_3_attn_attn_weight, None) # t433: "cuda:0 bf16[1, 512, 12288]" - t439 = prims.reshape(t433, (1, 512, 32, 3, 128)) # t439: "cuda:0 bf16[1, 512, 32, 3, 128]" - t445 = prims.transpose(t439, (0, 2, 3, 1, 4)) # t445: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t446, t447, t448) = ltorch.split(t445, (1, 1, 1), 2) - # t446 = prims.slice_prim(t445, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t446: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t447 = prims.slice_prim(t445, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t447: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t448 = prims.slice_prim(t445, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t448: "cuda:0 bf16[1, 32, 1, 512, 128]" - t454 = prims.reshape(t446, (1, 32, 512, 128)) # t454: "cuda:0 bf16[1, 32, 512, 128]" - t460 = prims.reshape(t447, (1, 32, 512, 128)) # t460: "cuda:0 bf16[1, 32, 512, 128]" - t466 = prims.reshape(t448, (1, 32, 512, 128)) # t466: "cuda:0 bf16[1, 32, 512, 128]" - t467 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t467: "cuda:0 bf16[1, 32, 512, 128]" - t468 = prims.slice_prim(t467, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t468: "cuda:0 bf16[1, 32, 512, 64]" - t469 = prims.slice_prim(t467, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t469: "cuda:0 bf16[1, 32, 512, 64]" - t470 = prims.convert_element_type(t469, dtypes.float32) # t470: "cuda:0 f32[1, 32, 512, 64]" - t471 = prims.neg(t470) # t471: "cuda:0 f32[1, 32, 512, 64]" - t472 = prims.convert_element_type(t471, dtypes.bfloat16) # t472: "cuda:0 bf16[1, 32, 512, 64]" - t474 = prims.cat((t472, t468), -1) # t474: "cuda:0 bf16[1, 32, 512, 128]" - t475 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t475: "cuda:0 f32[1, 32, 512, 128]" - t476 = prims.convert_element_type(t467, dtypes.float32) # t476: "cuda:0 f32[1, 32, 512, 128]" - t477 = ltorch.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - # t477 = prims.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - t478 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t478: "cuda:0 f32[1, 32, 512, 128]" - t479 = prims.convert_element_type(t474, dtypes.float32) # t479: "cuda:0 f32[1, 32, 512, 128]" - t480 = ltorch.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - # t480 = prims.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - t481 = ltorch.add(t477, t480, alpha=None) # t481: "cuda:0 f32[1, 32, 512, 128]" - # t481 = prims.add(t477, t480) # t481: "cuda:0 f32[1, 32, 512, 128]" - t482 = prims.convert_element_type(t481, dtypes.bfloat16) # t482: "cuda:0 bf16[1, 32, 512, 128]" - t483 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t483: "cuda:0 bf16[1, 32, 512, 128]" - t484 = prims.slice_prim(t483, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t484: "cuda:0 bf16[1, 32, 512, 64]" - t485 = prims.slice_prim(t483, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t485: "cuda:0 bf16[1, 32, 512, 64]" - t486 = prims.convert_element_type(t485, dtypes.float32) # t486: "cuda:0 f32[1, 32, 512, 64]" - t487 = prims.neg(t486) # t487: "cuda:0 f32[1, 32, 512, 64]" - t488 = prims.convert_element_type(t487, dtypes.bfloat16) # t488: "cuda:0 bf16[1, 32, 512, 64]" - t490 = prims.cat((t488, t484), -1) # t490: "cuda:0 bf16[1, 32, 512, 128]" - t491 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t491: "cuda:0 f32[1, 32, 512, 128]" - t492 = prims.convert_element_type(t483, dtypes.float32) # t492: "cuda:0 f32[1, 32, 512, 128]" - t493 = ltorch.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - # t493 = prims.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - t494 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t494: "cuda:0 f32[1, 32, 512, 128]" - t495 = prims.convert_element_type(t490, dtypes.float32) # t495: "cuda:0 f32[1, 32, 512, 128]" - t496 = ltorch.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - # t496 = prims.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - t497 = ltorch.add(t493, t496, alpha=None) # t497: "cuda:0 f32[1, 32, 512, 128]" - # t497 = prims.add(t493, t496) # t497: "cuda:0 f32[1, 32, 512, 128]" - t498 = prims.convert_element_type(t497, dtypes.bfloat16) # t498: "cuda:0 bf16[1, 32, 512, 128]" - t499 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t499: "cuda:0 bf16[1, 32, 512, 0]" - t501 = prims.cat((t482, t499), -1) # t501: "cuda:0 bf16[1, 32, 512, 128]" - t502 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t502: "cuda:0 bf16[1, 32, 512, 0]" - t504 = prims.cat((t498, t502), -1) # t504: "cuda:0 bf16[1, 32, 512, 128]" - (t505, t506, t507, t508) = cudnn_sdpa_fwd(t501, t504, t466, None, 0.0, True, scale=0.08838834764831843) - t511 = prims.transpose(t505, (0, 2, 1, 3)) # t511: "cuda:0 bf16[1, 512, 32, 128]" - t515 = prims.reshape(t511, (1, 512, 4096)) # t515: "cuda:0 bf16[1, 512, 4096]" - t516 = prims.linear(t515, t_transformer_h_3_attn_proj_weight, None) # t516: "cuda:0 bf16[1, 512, 4096]" - t517 = prims.convert_element_type(t516, dtypes.float32) # t517: "cuda:0 f32[1, 512, 4096]" - t518 = prims.convert_element_type(t414, dtypes.float32) # t518: "cuda:0 f32[1, 512, 4096]" - t519 = ltorch.add(t517, t518, alpha=None) # t519: "cuda:0 f32[1, 512, 4096]" - # t519 = prims.add(t517, t518) # t519: "cuda:0 f32[1, 512, 4096]" - t520 = prims.convert_element_type(t519, dtypes.bfloat16) # t520: "cuda:0 bf16[1, 512, 4096]" - t521 = prims.convert_element_type(t520, dtypes.float32) # t521: "cuda:0 f32[1, 512, 4096]" - t522 = ltorch.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - # t522 = prims.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - t524 = prims.sum(t522, (2,)) # t524: "cuda:0 f32[1, 512]" - t525 = prims.broadcast_in_dim(t524, [1, 512, 1], [0, 1]) # t525: "cuda:0 f32[1, 512, 1]" - t527 = ltorch.true_divide(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - # t527 = prims.div(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - t529 = ltorch.add(t527, 1e-05, alpha=None) # t529: "cuda:0 f32[1, 512, 1]" - # t529 = prims.add(t527, 1e-05) # t529: "cuda:0 f32[1, 512, 1]" - t530 = prims.rsqrt(t529) # t530: "cuda:0 f32[1, 512, 1]" - t531 = prims.broadcast_in_dim(t530, (1, 512, 4096), (0, 1, 2)) # t531: "cuda:0 f32[1, 512, 4096]" - t532 = ltorch.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - # t532 = prims.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: "cuda:0 bf16[1, 512, 4096]" - t534 = prims.broadcast_in_dim(t_transformer_h_3_norm_2_weight, (1, 512, 4096), (2,)) # t534: "cuda:0 bf16[1, 512, 4096]" - t535 = prims.convert_element_type(t533, dtypes.float32) # t535: "cuda:0 f32[1, 512, 4096]" - t536 = prims.convert_element_type(t534, dtypes.float32) # t536: "cuda:0 f32[1, 512, 4096]" - t537 = ltorch.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - # t537 = prims.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - t538 = prims.convert_element_type(t537, dtypes.bfloat16) # t538: "cuda:0 bf16[1, 512, 4096]" - t539 = prims.linear(t538, t_transformer_h_3_mlp_fc_1_weight, None) # t539: "cuda:0 bf16[1, 512, 11008]" - t540 = prims.linear(t538, t_transformer_h_3_mlp_fc_2_weight, None) # t540: "cuda:0 bf16[1, 512, 11008]" - t541 = prims.convert_element_type(t539, dtypes.float32) # t541: "cuda:0 f32[1, 512, 11008]" - t542 = prims.neg(t541) # t542: "cuda:0 f32[1, 512, 11008]" - t543 = prims.exp(t542) # t543: "cuda:0 f32[1, 512, 11008]" - t544 = ltorch.add(1.0, t543, alpha=None) # t544: "cuda:0 f32[1, 512, 11008]" - # t544 = prims.add(1.0, t543) # t544: "cuda:0 f32[1, 512, 11008]" - t545 = prims.reciprocal(t544) # t545: "cuda:0 f32[1, 512, 11008]" - t546 = prims.convert_element_type(t545, dtypes.bfloat16) # t546: "cuda:0 bf16[1, 512, 11008]" - t547 = prims.convert_element_type(t539, dtypes.float32) # t547: "cuda:0 f32[1, 512, 11008]" - t548 = prims.convert_element_type(t546, dtypes.float32) # t548: "cuda:0 f32[1, 512, 11008]" - t549 = ltorch.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - # t549 = prims.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - t550 = prims.convert_element_type(t549, dtypes.bfloat16) # t550: "cuda:0 bf16[1, 512, 11008]" - t551 = prims.convert_element_type(t550, dtypes.float32) # t551: "cuda:0 f32[1, 512, 11008]" - t552 = prims.convert_element_type(t540, dtypes.float32) # t552: "cuda:0 f32[1, 512, 11008]" - t553 = ltorch.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - # t553 = prims.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: "cuda:0 bf16[1, 512, 11008]" - t555 = prims.linear(t554, t_transformer_h_3_mlp_proj_weight, None) # t555: "cuda:0 bf16[1, 512, 4096]" - t556 = prims.convert_element_type(t555, dtypes.float32) # t556: "cuda:0 f32[1, 512, 4096]" - t557 = prims.convert_element_type(t520, dtypes.float32) # t557: "cuda:0 f32[1, 512, 4096]" - t558 = ltorch.add(t556, t557, alpha=None) # t558: "cuda:0 f32[1, 512, 4096]" - # t558 = prims.add(t556, t557) # t558: "cuda:0 f32[1, 512, 4096]" - t559 = prims.convert_element_type(t558, dtypes.bfloat16) # t559: "cuda:0 bf16[1, 512, 4096]" - t560 = prims.convert_element_type(t559, dtypes.float32) # t560: "cuda:0 f32[1, 512, 4096]" - t561 = ltorch.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - # t561 = prims.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - t563 = prims.sum(t561, (2,)) # t563: "cuda:0 f32[1, 512]" - t564 = prims.broadcast_in_dim(t563, [1, 512, 1], [0, 1]) # t564: "cuda:0 f32[1, 512, 1]" - t566 = ltorch.true_divide(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - # t566 = prims.div(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - t568 = ltorch.add(t566, 1e-05, alpha=None) # t568: "cuda:0 f32[1, 512, 1]" - # t568 = prims.add(t566, 1e-05) # t568: "cuda:0 f32[1, 512, 1]" - t569 = prims.rsqrt(t568) # t569: "cuda:0 f32[1, 512, 1]" - t570 = prims.broadcast_in_dim(t569, (1, 512, 4096), (0, 1, 2)) # t570: "cuda:0 f32[1, 512, 4096]" - t571 = ltorch.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - # t571 = prims.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - t572 = prims.convert_element_type(t571, dtypes.bfloat16) # t572: "cuda:0 bf16[1, 512, 4096]" - t573 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_weight, (1, 512, 4096), (2,)) # t573: "cuda:0 bf16[1, 512, 4096]" - t574 = prims.convert_element_type(t572, dtypes.float32) # t574: "cuda:0 f32[1, 512, 4096]" - t575 = prims.convert_element_type(t573, dtypes.float32) # t575: "cuda:0 f32[1, 512, 4096]" - t576 = ltorch.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - # t576 = prims.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - t577 = prims.convert_element_type(t576, dtypes.bfloat16) # t577: "cuda:0 bf16[1, 512, 4096]" - t578 = prims.linear(t577, t_transformer_h_4_attn_attn_weight, None) # t578: "cuda:0 bf16[1, 512, 12288]" - t584 = prims.reshape(t578, (1, 512, 32, 3, 128)) # t584: "cuda:0 bf16[1, 512, 32, 3, 128]" - t590 = prims.transpose(t584, (0, 2, 3, 1, 4)) # t590: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t591, t592, t593) = ltorch.split(t590, (1, 1, 1), 2) - # t591 = prims.slice_prim(t590, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t591: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t592 = prims.slice_prim(t590, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t592: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t593 = prims.slice_prim(t590, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t593: "cuda:0 bf16[1, 32, 1, 512, 128]" - t599 = prims.reshape(t591, (1, 32, 512, 128)) # t599: "cuda:0 bf16[1, 32, 512, 128]" - t605 = prims.reshape(t592, (1, 32, 512, 128)) # t605: "cuda:0 bf16[1, 32, 512, 128]" - t611 = prims.reshape(t593, (1, 32, 512, 128)) # t611: "cuda:0 bf16[1, 32, 512, 128]" - t612 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t612: "cuda:0 bf16[1, 32, 512, 128]" - t613 = prims.slice_prim(t612, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t613: "cuda:0 bf16[1, 32, 512, 64]" - t614 = prims.slice_prim(t612, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t614: "cuda:0 bf16[1, 32, 512, 64]" - t615 = prims.convert_element_type(t614, dtypes.float32) # t615: "cuda:0 f32[1, 32, 512, 64]" - t616 = prims.neg(t615) # t616: "cuda:0 f32[1, 32, 512, 64]" - t617 = prims.convert_element_type(t616, dtypes.bfloat16) # t617: "cuda:0 bf16[1, 32, 512, 64]" - t619 = prims.cat((t617, t613), -1) # t619: "cuda:0 bf16[1, 32, 512, 128]" - t620 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t620: "cuda:0 f32[1, 32, 512, 128]" - t621 = prims.convert_element_type(t612, dtypes.float32) # t621: "cuda:0 f32[1, 32, 512, 128]" - t622 = ltorch.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - # t622 = prims.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - t623 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t623: "cuda:0 f32[1, 32, 512, 128]" - t624 = prims.convert_element_type(t619, dtypes.float32) # t624: "cuda:0 f32[1, 32, 512, 128]" - t625 = ltorch.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - # t625 = prims.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - t626 = ltorch.add(t622, t625, alpha=None) # t626: "cuda:0 f32[1, 32, 512, 128]" - # t626 = prims.add(t622, t625) # t626: "cuda:0 f32[1, 32, 512, 128]" - t627 = prims.convert_element_type(t626, dtypes.bfloat16) # t627: "cuda:0 bf16[1, 32, 512, 128]" - t628 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t628: "cuda:0 bf16[1, 32, 512, 128]" - t629 = prims.slice_prim(t628, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t629: "cuda:0 bf16[1, 32, 512, 64]" - t630 = prims.slice_prim(t628, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t630: "cuda:0 bf16[1, 32, 512, 64]" - t631 = prims.convert_element_type(t630, dtypes.float32) # t631: "cuda:0 f32[1, 32, 512, 64]" - t632 = prims.neg(t631) # t632: "cuda:0 f32[1, 32, 512, 64]" - t633 = prims.convert_element_type(t632, dtypes.bfloat16) # t633: "cuda:0 bf16[1, 32, 512, 64]" - t635 = prims.cat((t633, t629), -1) # t635: "cuda:0 bf16[1, 32, 512, 128]" - t636 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t636: "cuda:0 f32[1, 32, 512, 128]" - t637 = prims.convert_element_type(t628, dtypes.float32) # t637: "cuda:0 f32[1, 32, 512, 128]" - t638 = ltorch.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - # t638 = prims.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - t639 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t639: "cuda:0 f32[1, 32, 512, 128]" - t640 = prims.convert_element_type(t635, dtypes.float32) # t640: "cuda:0 f32[1, 32, 512, 128]" - t641 = ltorch.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - # t641 = prims.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - t642 = ltorch.add(t638, t641, alpha=None) # t642: "cuda:0 f32[1, 32, 512, 128]" - # t642 = prims.add(t638, t641) # t642: "cuda:0 f32[1, 32, 512, 128]" - t643 = prims.convert_element_type(t642, dtypes.bfloat16) # t643: "cuda:0 bf16[1, 32, 512, 128]" - t644 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t644: "cuda:0 bf16[1, 32, 512, 0]" - t646 = prims.cat((t627, t644), -1) # t646: "cuda:0 bf16[1, 32, 512, 128]" - t647 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t647: "cuda:0 bf16[1, 32, 512, 0]" - t649 = prims.cat((t643, t647), -1) # t649: "cuda:0 bf16[1, 32, 512, 128]" - (t650, t651, t652, t653) = cudnn_sdpa_fwd(t646, t649, t611, None, 0.0, True, scale=0.08838834764831843) - t656 = prims.transpose(t650, (0, 2, 1, 3)) # t656: "cuda:0 bf16[1, 512, 32, 128]" - t660 = prims.reshape(t656, (1, 512, 4096)) # t660: "cuda:0 bf16[1, 512, 4096]" - t661 = prims.linear(t660, t_transformer_h_4_attn_proj_weight, None) # t661: "cuda:0 bf16[1, 512, 4096]" - t662 = prims.convert_element_type(t661, dtypes.float32) # t662: "cuda:0 f32[1, 512, 4096]" - t663 = prims.convert_element_type(t559, dtypes.float32) # t663: "cuda:0 f32[1, 512, 4096]" - t664 = ltorch.add(t662, t663, alpha=None) # t664: "cuda:0 f32[1, 512, 4096]" - # t664 = prims.add(t662, t663) # t664: "cuda:0 f32[1, 512, 4096]" - t665 = prims.convert_element_type(t664, dtypes.bfloat16) # t665: "cuda:0 bf16[1, 512, 4096]" - t666 = prims.convert_element_type(t665, dtypes.float32) # t666: "cuda:0 f32[1, 512, 4096]" - t667 = ltorch.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - # t667 = prims.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - t669 = prims.sum(t667, (2,)) # t669: "cuda:0 f32[1, 512]" - t670 = prims.broadcast_in_dim(t669, [1, 512, 1], [0, 1]) # t670: "cuda:0 f32[1, 512, 1]" - t672 = ltorch.true_divide(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - # t672 = prims.div(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - t674 = ltorch.add(t672, 1e-05, alpha=None) # t674: "cuda:0 f32[1, 512, 1]" - # t674 = prims.add(t672, 1e-05) # t674: "cuda:0 f32[1, 512, 1]" - t675 = prims.rsqrt(t674) # t675: "cuda:0 f32[1, 512, 1]" - t676 = prims.broadcast_in_dim(t675, (1, 512, 4096), (0, 1, 2)) # t676: "cuda:0 f32[1, 512, 4096]" - t677 = ltorch.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - # t677 = prims.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - t678 = prims.convert_element_type(t677, dtypes.bfloat16) # t678: "cuda:0 bf16[1, 512, 4096]" - t679 = prims.broadcast_in_dim(t_transformer_h_4_norm_2_weight, (1, 512, 4096), (2,)) # t679: "cuda:0 bf16[1, 512, 4096]" - t680 = prims.convert_element_type(t678, dtypes.float32) # t680: "cuda:0 f32[1, 512, 4096]" - t681 = prims.convert_element_type(t679, dtypes.float32) # t681: "cuda:0 f32[1, 512, 4096]" - t682 = ltorch.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - # t682 = prims.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - t683 = prims.convert_element_type(t682, dtypes.bfloat16) # t683: "cuda:0 bf16[1, 512, 4096]" - t684 = prims.linear(t683, t_transformer_h_4_mlp_fc_1_weight, None) # t684: "cuda:0 bf16[1, 512, 11008]" - t685 = prims.linear(t683, t_transformer_h_4_mlp_fc_2_weight, None) # t685: "cuda:0 bf16[1, 512, 11008]" - t686 = prims.convert_element_type(t684, dtypes.float32) # t686: "cuda:0 f32[1, 512, 11008]" - t687 = prims.neg(t686) # t687: "cuda:0 f32[1, 512, 11008]" - t688 = prims.exp(t687) # t688: "cuda:0 f32[1, 512, 11008]" - t689 = ltorch.add(1.0, t688, alpha=None) # t689: "cuda:0 f32[1, 512, 11008]" - # t689 = prims.add(1.0, t688) # t689: "cuda:0 f32[1, 512, 11008]" - t690 = prims.reciprocal(t689) # t690: "cuda:0 f32[1, 512, 11008]" - t691 = prims.convert_element_type(t690, dtypes.bfloat16) # t691: "cuda:0 bf16[1, 512, 11008]" - t692 = prims.convert_element_type(t684, dtypes.float32) # t692: "cuda:0 f32[1, 512, 11008]" - t693 = prims.convert_element_type(t691, dtypes.float32) # t693: "cuda:0 f32[1, 512, 11008]" - t694 = ltorch.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - # t694 = prims.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - t695 = prims.convert_element_type(t694, dtypes.bfloat16) # t695: "cuda:0 bf16[1, 512, 11008]" - t696 = prims.convert_element_type(t695, dtypes.float32) # t696: "cuda:0 f32[1, 512, 11008]" - t697 = prims.convert_element_type(t685, dtypes.float32) # t697: "cuda:0 f32[1, 512, 11008]" - t698 = ltorch.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - # t698 = prims.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - t699 = prims.convert_element_type(t698, dtypes.bfloat16) # t699: "cuda:0 bf16[1, 512, 11008]" - t700 = prims.linear(t699, t_transformer_h_4_mlp_proj_weight, None) # t700: "cuda:0 bf16[1, 512, 4096]" - t701 = prims.convert_element_type(t700, dtypes.float32) # t701: "cuda:0 f32[1, 512, 4096]" - t702 = prims.convert_element_type(t665, dtypes.float32) # t702: "cuda:0 f32[1, 512, 4096]" - t703 = ltorch.add(t701, t702, alpha=None) # t703: "cuda:0 f32[1, 512, 4096]" - # t703 = prims.add(t701, t702) # t703: "cuda:0 f32[1, 512, 4096]" - t704 = prims.convert_element_type(t703, dtypes.bfloat16) # t704: "cuda:0 bf16[1, 512, 4096]" - t705 = prims.convert_element_type(t704, dtypes.float32) # t705: "cuda:0 f32[1, 512, 4096]" - t706 = ltorch.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - # t706 = prims.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - t708 = prims.sum(t706, (2,)) # t708: "cuda:0 f32[1, 512]" - t709 = prims.broadcast_in_dim(t708, [1, 512, 1], [0, 1]) # t709: "cuda:0 f32[1, 512, 1]" - t711 = ltorch.true_divide(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - # t711 = prims.div(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - t713 = ltorch.add(t711, 1e-05, alpha=None) # t713: "cuda:0 f32[1, 512, 1]" - # t713 = prims.add(t711, 1e-05) # t713: "cuda:0 f32[1, 512, 1]" - t714 = prims.rsqrt(t713) # t714: "cuda:0 f32[1, 512, 1]" - t715 = prims.broadcast_in_dim(t714, (1, 512, 4096), (0, 1, 2)) # t715: "cuda:0 f32[1, 512, 4096]" - t716 = ltorch.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - # t716 = prims.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - t717 = prims.convert_element_type(t716, dtypes.bfloat16) # t717: "cuda:0 bf16[1, 512, 4096]" - t718 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_weight, (1, 512, 4096), (2,)) # t718: "cuda:0 bf16[1, 512, 4096]" - t719 = prims.convert_element_type(t717, dtypes.float32) # t719: "cuda:0 f32[1, 512, 4096]" - t720 = prims.convert_element_type(t718, dtypes.float32) # t720: "cuda:0 f32[1, 512, 4096]" - t721 = ltorch.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - # t721 = prims.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - t722 = prims.convert_element_type(t721, dtypes.bfloat16) # t722: "cuda:0 bf16[1, 512, 4096]" - t723 = prims.linear(t722, t_transformer_h_5_attn_attn_weight, None) # t723: "cuda:0 bf16[1, 512, 12288]" - t729 = prims.reshape(t723, (1, 512, 32, 3, 128)) # t729: "cuda:0 bf16[1, 512, 32, 3, 128]" - t735 = prims.transpose(t729, (0, 2, 3, 1, 4)) # t735: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t736, t737, t738) = ltorch.split(t735, (1, 1, 1), 2) - # t736 = prims.slice_prim(t735, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t736: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t737 = prims.slice_prim(t735, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t737: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t738 = prims.slice_prim(t735, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t738: "cuda:0 bf16[1, 32, 1, 512, 128]" - t744 = prims.reshape(t736, (1, 32, 512, 128)) # t744: "cuda:0 bf16[1, 32, 512, 128]" - t750 = prims.reshape(t737, (1, 32, 512, 128)) # t750: "cuda:0 bf16[1, 32, 512, 128]" - t756 = prims.reshape(t738, (1, 32, 512, 128)) # t756: "cuda:0 bf16[1, 32, 512, 128]" - t757 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t757: "cuda:0 bf16[1, 32, 512, 128]" - t758 = prims.slice_prim(t757, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t758: "cuda:0 bf16[1, 32, 512, 64]" - t759 = prims.slice_prim(t757, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t759: "cuda:0 bf16[1, 32, 512, 64]" - t760 = prims.convert_element_type(t759, dtypes.float32) # t760: "cuda:0 f32[1, 32, 512, 64]" - t761 = prims.neg(t760) # t761: "cuda:0 f32[1, 32, 512, 64]" - t762 = prims.convert_element_type(t761, dtypes.bfloat16) # t762: "cuda:0 bf16[1, 32, 512, 64]" - t764 = prims.cat((t762, t758), -1) # t764: "cuda:0 bf16[1, 32, 512, 128]" - t765 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t765: "cuda:0 f32[1, 32, 512, 128]" - t766 = prims.convert_element_type(t757, dtypes.float32) # t766: "cuda:0 f32[1, 32, 512, 128]" - t767 = ltorch.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - # t767 = prims.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - t768 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t768: "cuda:0 f32[1, 32, 512, 128]" - t769 = prims.convert_element_type(t764, dtypes.float32) # t769: "cuda:0 f32[1, 32, 512, 128]" - t770 = ltorch.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - # t770 = prims.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - t771 = ltorch.add(t767, t770, alpha=None) # t771: "cuda:0 f32[1, 32, 512, 128]" - # t771 = prims.add(t767, t770) # t771: "cuda:0 f32[1, 32, 512, 128]" - t772 = prims.convert_element_type(t771, dtypes.bfloat16) # t772: "cuda:0 bf16[1, 32, 512, 128]" - t773 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t773: "cuda:0 bf16[1, 32, 512, 128]" - t774 = prims.slice_prim(t773, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t774: "cuda:0 bf16[1, 32, 512, 64]" - t775 = prims.slice_prim(t773, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t775: "cuda:0 bf16[1, 32, 512, 64]" - t776 = prims.convert_element_type(t775, dtypes.float32) # t776: "cuda:0 f32[1, 32, 512, 64]" - t777 = prims.neg(t776) # t777: "cuda:0 f32[1, 32, 512, 64]" - t778 = prims.convert_element_type(t777, dtypes.bfloat16) # t778: "cuda:0 bf16[1, 32, 512, 64]" - t780 = prims.cat((t778, t774), -1) # t780: "cuda:0 bf16[1, 32, 512, 128]" - t781 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t781: "cuda:0 f32[1, 32, 512, 128]" - t782 = prims.convert_element_type(t773, dtypes.float32) # t782: "cuda:0 f32[1, 32, 512, 128]" - t783 = ltorch.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - # t783 = prims.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - t784 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t784: "cuda:0 f32[1, 32, 512, 128]" - t785 = prims.convert_element_type(t780, dtypes.float32) # t785: "cuda:0 f32[1, 32, 512, 128]" - t786 = ltorch.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - # t786 = prims.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - t787 = ltorch.add(t783, t786, alpha=None) # t787: "cuda:0 f32[1, 32, 512, 128]" - # t787 = prims.add(t783, t786) # t787: "cuda:0 f32[1, 32, 512, 128]" - t788 = prims.convert_element_type(t787, dtypes.bfloat16) # t788: "cuda:0 bf16[1, 32, 512, 128]" - t789 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t789: "cuda:0 bf16[1, 32, 512, 0]" - t791 = prims.cat((t772, t789), -1) # t791: "cuda:0 bf16[1, 32, 512, 128]" - t792 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t792: "cuda:0 bf16[1, 32, 512, 0]" - t794 = prims.cat((t788, t792), -1) # t794: "cuda:0 bf16[1, 32, 512, 128]" - (t795, t796, t797, t798) = cudnn_sdpa_fwd(t791, t794, t756, None, 0.0, True, scale=0.08838834764831843) - t801 = prims.transpose(t795, (0, 2, 1, 3)) # t801: "cuda:0 bf16[1, 512, 32, 128]" - t805 = prims.reshape(t801, (1, 512, 4096)) # t805: "cuda:0 bf16[1, 512, 4096]" - t806 = prims.linear(t805, t_transformer_h_5_attn_proj_weight, None) # t806: "cuda:0 bf16[1, 512, 4096]" - t807 = prims.convert_element_type(t806, dtypes.float32) # t807: "cuda:0 f32[1, 512, 4096]" - t808 = prims.convert_element_type(t704, dtypes.float32) # t808: "cuda:0 f32[1, 512, 4096]" - t809 = ltorch.add(t807, t808, alpha=None) # t809: "cuda:0 f32[1, 512, 4096]" - # t809 = prims.add(t807, t808) # t809: "cuda:0 f32[1, 512, 4096]" - t810 = prims.convert_element_type(t809, dtypes.bfloat16) # t810: "cuda:0 bf16[1, 512, 4096]" - t811 = prims.convert_element_type(t810, dtypes.float32) # t811: "cuda:0 f32[1, 512, 4096]" - t812 = ltorch.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - # t812 = prims.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - t814 = prims.sum(t812, (2,)) # t814: "cuda:0 f32[1, 512]" - t815 = prims.broadcast_in_dim(t814, [1, 512, 1], [0, 1]) # t815: "cuda:0 f32[1, 512, 1]" - t817 = ltorch.true_divide(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - # t817 = prims.div(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - t819 = ltorch.add(t817, 1e-05, alpha=None) # t819: "cuda:0 f32[1, 512, 1]" - # t819 = prims.add(t817, 1e-05) # t819: "cuda:0 f32[1, 512, 1]" - t820 = prims.rsqrt(t819) # t820: "cuda:0 f32[1, 512, 1]" - t821 = prims.broadcast_in_dim(t820, (1, 512, 4096), (0, 1, 2)) # t821: "cuda:0 f32[1, 512, 4096]" - t822 = ltorch.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - # t822 = prims.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - t823 = prims.convert_element_type(t822, dtypes.bfloat16) # t823: "cuda:0 bf16[1, 512, 4096]" - t824 = prims.broadcast_in_dim(t_transformer_h_5_norm_2_weight, (1, 512, 4096), (2,)) # t824: "cuda:0 bf16[1, 512, 4096]" - t825 = prims.convert_element_type(t823, dtypes.float32) # t825: "cuda:0 f32[1, 512, 4096]" - t826 = prims.convert_element_type(t824, dtypes.float32) # t826: "cuda:0 f32[1, 512, 4096]" - t827 = ltorch.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - # t827 = prims.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - t828 = prims.convert_element_type(t827, dtypes.bfloat16) # t828: "cuda:0 bf16[1, 512, 4096]" - t829 = prims.linear(t828, t_transformer_h_5_mlp_fc_1_weight, None) # t829: "cuda:0 bf16[1, 512, 11008]" - t830 = prims.linear(t828, t_transformer_h_5_mlp_fc_2_weight, None) # t830: "cuda:0 bf16[1, 512, 11008]" - t831 = prims.convert_element_type(t829, dtypes.float32) # t831: "cuda:0 f32[1, 512, 11008]" - t832 = prims.neg(t831) # t832: "cuda:0 f32[1, 512, 11008]" - t833 = prims.exp(t832) # t833: "cuda:0 f32[1, 512, 11008]" - t834 = ltorch.add(1.0, t833, alpha=None) # t834: "cuda:0 f32[1, 512, 11008]" - # t834 = prims.add(1.0, t833) # t834: "cuda:0 f32[1, 512, 11008]" - t835 = prims.reciprocal(t834) # t835: "cuda:0 f32[1, 512, 11008]" - t836 = prims.convert_element_type(t835, dtypes.bfloat16) # t836: "cuda:0 bf16[1, 512, 11008]" - t837 = prims.convert_element_type(t829, dtypes.float32) # t837: "cuda:0 f32[1, 512, 11008]" - t838 = prims.convert_element_type(t836, dtypes.float32) # t838: "cuda:0 f32[1, 512, 11008]" - t839 = ltorch.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - # t839 = prims.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - t840 = prims.convert_element_type(t839, dtypes.bfloat16) # t840: "cuda:0 bf16[1, 512, 11008]" - t841 = prims.convert_element_type(t840, dtypes.float32) # t841: "cuda:0 f32[1, 512, 11008]" - t842 = prims.convert_element_type(t830, dtypes.float32) # t842: "cuda:0 f32[1, 512, 11008]" - t843 = ltorch.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - # t843 = prims.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - t844 = prims.convert_element_type(t843, dtypes.bfloat16) # t844: "cuda:0 bf16[1, 512, 11008]" - t845 = prims.linear(t844, t_transformer_h_5_mlp_proj_weight, None) # t845: "cuda:0 bf16[1, 512, 4096]" - t846 = prims.convert_element_type(t845, dtypes.float32) # t846: "cuda:0 f32[1, 512, 4096]" - t847 = prims.convert_element_type(t810, dtypes.float32) # t847: "cuda:0 f32[1, 512, 4096]" - t848 = ltorch.add(t846, t847, alpha=None) # t848: "cuda:0 f32[1, 512, 4096]" - # t848 = prims.add(t846, t847) # t848: "cuda:0 f32[1, 512, 4096]" - t849 = prims.convert_element_type(t848, dtypes.bfloat16) # t849: "cuda:0 bf16[1, 512, 4096]" - t850 = prims.convert_element_type(t849, dtypes.float32) # t850: "cuda:0 f32[1, 512, 4096]" - t851 = ltorch.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - # t851 = prims.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - t853 = prims.sum(t851, (2,)) # t853: "cuda:0 f32[1, 512]" - t854 = prims.broadcast_in_dim(t853, [1, 512, 1], [0, 1]) # t854: "cuda:0 f32[1, 512, 1]" - t856 = ltorch.true_divide(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - # t856 = prims.div(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - t858 = ltorch.add(t856, 1e-05, alpha=None) # t858: "cuda:0 f32[1, 512, 1]" - # t858 = prims.add(t856, 1e-05) # t858: "cuda:0 f32[1, 512, 1]" - t859 = prims.rsqrt(t858) # t859: "cuda:0 f32[1, 512, 1]" - t860 = prims.broadcast_in_dim(t859, (1, 512, 4096), (0, 1, 2)) # t860: "cuda:0 f32[1, 512, 4096]" - t861 = ltorch.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - # t861 = prims.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - t862 = prims.convert_element_type(t861, dtypes.bfloat16) # t862: "cuda:0 bf16[1, 512, 4096]" - t863 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_weight, (1, 512, 4096), (2,)) # t863: "cuda:0 bf16[1, 512, 4096]" - t864 = prims.convert_element_type(t862, dtypes.float32) # t864: "cuda:0 f32[1, 512, 4096]" - t865 = prims.convert_element_type(t863, dtypes.float32) # t865: "cuda:0 f32[1, 512, 4096]" - t866 = ltorch.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - # t866 = prims.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - t867 = prims.convert_element_type(t866, dtypes.bfloat16) # t867: "cuda:0 bf16[1, 512, 4096]" - t868 = prims.linear(t867, t_transformer_h_6_attn_attn_weight, None) # t868: "cuda:0 bf16[1, 512, 12288]" - t874 = prims.reshape(t868, (1, 512, 32, 3, 128)) # t874: "cuda:0 bf16[1, 512, 32, 3, 128]" - t880 = prims.transpose(t874, (0, 2, 3, 1, 4)) # t880: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t881, t882, t883) = ltorch.split(t880, (1, 1, 1), 2) - # t881 = prims.slice_prim(t880, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t881: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t882 = prims.slice_prim(t880, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t882: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t883 = prims.slice_prim(t880, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t883: "cuda:0 bf16[1, 32, 1, 512, 128]" - t889 = prims.reshape(t881, (1, 32, 512, 128)) # t889: "cuda:0 bf16[1, 32, 512, 128]" - t895 = prims.reshape(t882, (1, 32, 512, 128)) # t895: "cuda:0 bf16[1, 32, 512, 128]" - t901 = prims.reshape(t883, (1, 32, 512, 128)) # t901: "cuda:0 bf16[1, 32, 512, 128]" - t902 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t902: "cuda:0 bf16[1, 32, 512, 128]" - t903 = prims.slice_prim(t902, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t903: "cuda:0 bf16[1, 32, 512, 64]" - t904 = prims.slice_prim(t902, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t904: "cuda:0 bf16[1, 32, 512, 64]" - t905 = prims.convert_element_type(t904, dtypes.float32) # t905: "cuda:0 f32[1, 32, 512, 64]" - t906 = prims.neg(t905) # t906: "cuda:0 f32[1, 32, 512, 64]" - t907 = prims.convert_element_type(t906, dtypes.bfloat16) # t907: "cuda:0 bf16[1, 32, 512, 64]" - t909 = prims.cat((t907, t903), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - t910 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t910: "cuda:0 f32[1, 32, 512, 128]" - t911 = prims.convert_element_type(t902, dtypes.float32) # t911: "cuda:0 f32[1, 32, 512, 128]" - t912 = ltorch.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - # t912 = prims.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - t913 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t913: "cuda:0 f32[1, 32, 512, 128]" - t914 = prims.convert_element_type(t909, dtypes.float32) # t914: "cuda:0 f32[1, 32, 512, 128]" - t915 = ltorch.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - # t915 = prims.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - t916 = ltorch.add(t912, t915, alpha=None) # t916: "cuda:0 f32[1, 32, 512, 128]" - # t916 = prims.add(t912, t915) # t916: "cuda:0 f32[1, 32, 512, 128]" - t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: "cuda:0 bf16[1, 32, 512, 128]" - t918 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: "cuda:0 bf16[1, 32, 512, 128]" - t919 = prims.slice_prim(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: "cuda:0 bf16[1, 32, 512, 64]" - t920 = prims.slice_prim(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: "cuda:0 bf16[1, 32, 512, 64]" - t921 = prims.convert_element_type(t920, dtypes.float32) # t921: "cuda:0 f32[1, 32, 512, 64]" - t922 = prims.neg(t921) # t922: "cuda:0 f32[1, 32, 512, 64]" - t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: "cuda:0 bf16[1, 32, 512, 64]" - t925 = prims.cat((t923, t919), -1) # t925: "cuda:0 bf16[1, 32, 512, 128]" - t926 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t926: "cuda:0 f32[1, 32, 512, 128]" - t927 = prims.convert_element_type(t918, dtypes.float32) # t927: "cuda:0 f32[1, 32, 512, 128]" - t928 = ltorch.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - # t928 = prims.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - t929 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t929: "cuda:0 f32[1, 32, 512, 128]" - t930 = prims.convert_element_type(t925, dtypes.float32) # t930: "cuda:0 f32[1, 32, 512, 128]" - t931 = ltorch.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - # t931 = prims.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - t932 = ltorch.add(t928, t931, alpha=None) # t932: "cuda:0 f32[1, 32, 512, 128]" - # t932 = prims.add(t928, t931) # t932: "cuda:0 f32[1, 32, 512, 128]" - t933 = prims.convert_element_type(t932, dtypes.bfloat16) # t933: "cuda:0 bf16[1, 32, 512, 128]" - t934 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t934: "cuda:0 bf16[1, 32, 512, 0]" - t936 = prims.cat((t917, t934), -1) # t936: "cuda:0 bf16[1, 32, 512, 128]" - t937 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t937: "cuda:0 bf16[1, 32, 512, 0]" - t939 = prims.cat((t933, t937), -1) # t939: "cuda:0 bf16[1, 32, 512, 128]" - (t940, t941, t942, t943) = cudnn_sdpa_fwd(t936, t939, t901, None, 0.0, True, scale=0.08838834764831843) - t946 = prims.transpose(t940, (0, 2, 1, 3)) # t946: "cuda:0 bf16[1, 512, 32, 128]" - t950 = prims.reshape(t946, (1, 512, 4096)) # t950: "cuda:0 bf16[1, 512, 4096]" - t951 = prims.linear(t950, t_transformer_h_6_attn_proj_weight, None) # t951: "cuda:0 bf16[1, 512, 4096]" - t952 = prims.convert_element_type(t951, dtypes.float32) # t952: "cuda:0 f32[1, 512, 4096]" - t953 = prims.convert_element_type(t849, dtypes.float32) # t953: "cuda:0 f32[1, 512, 4096]" - t954 = ltorch.add(t952, t953, alpha=None) # t954: "cuda:0 f32[1, 512, 4096]" - # t954 = prims.add(t952, t953) # t954: "cuda:0 f32[1, 512, 4096]" - t955 = prims.convert_element_type(t954, dtypes.bfloat16) # t955: "cuda:0 bf16[1, 512, 4096]" - t956 = prims.convert_element_type(t955, dtypes.float32) # t956: "cuda:0 f32[1, 512, 4096]" - t957 = ltorch.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - # t957 = prims.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - t959 = prims.sum(t957, (2,)) # t959: "cuda:0 f32[1, 512]" - t960 = prims.broadcast_in_dim(t959, [1, 512, 1], [0, 1]) # t960: "cuda:0 f32[1, 512, 1]" - t962 = ltorch.true_divide(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - # t962 = prims.div(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - t964 = ltorch.add(t962, 1e-05, alpha=None) # t964: "cuda:0 f32[1, 512, 1]" - # t964 = prims.add(t962, 1e-05) # t964: "cuda:0 f32[1, 512, 1]" - t965 = prims.rsqrt(t964) # t965: "cuda:0 f32[1, 512, 1]" - t966 = prims.broadcast_in_dim(t965, (1, 512, 4096), (0, 1, 2)) # t966: "cuda:0 f32[1, 512, 4096]" - t967 = ltorch.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - # t967 = prims.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - t968 = prims.convert_element_type(t967, dtypes.bfloat16) # t968: "cuda:0 bf16[1, 512, 4096]" - t969 = prims.broadcast_in_dim(t_transformer_h_6_norm_2_weight, (1, 512, 4096), (2,)) # t969: "cuda:0 bf16[1, 512, 4096]" - t970 = prims.convert_element_type(t968, dtypes.float32) # t970: "cuda:0 f32[1, 512, 4096]" - t971 = prims.convert_element_type(t969, dtypes.float32) # t971: "cuda:0 f32[1, 512, 4096]" - t972 = ltorch.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - # t972 = prims.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - t973 = prims.convert_element_type(t972, dtypes.bfloat16) # t973: "cuda:0 bf16[1, 512, 4096]" - t974 = prims.linear(t973, t_transformer_h_6_mlp_fc_1_weight, None) # t974: "cuda:0 bf16[1, 512, 11008]" - t975 = prims.linear(t973, t_transformer_h_6_mlp_fc_2_weight, None) # t975: "cuda:0 bf16[1, 512, 11008]" - t976 = prims.convert_element_type(t974, dtypes.float32) # t976: "cuda:0 f32[1, 512, 11008]" - t977 = prims.neg(t976) # t977: "cuda:0 f32[1, 512, 11008]" - t978 = prims.exp(t977) # t978: "cuda:0 f32[1, 512, 11008]" - t979 = ltorch.add(1.0, t978, alpha=None) # t979: "cuda:0 f32[1, 512, 11008]" - # t979 = prims.add(1.0, t978) # t979: "cuda:0 f32[1, 512, 11008]" - t980 = prims.reciprocal(t979) # t980: "cuda:0 f32[1, 512, 11008]" - t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: "cuda:0 bf16[1, 512, 11008]" - t982 = prims.convert_element_type(t974, dtypes.float32) # t982: "cuda:0 f32[1, 512, 11008]" - t983 = prims.convert_element_type(t981, dtypes.float32) # t983: "cuda:0 f32[1, 512, 11008]" - t984 = ltorch.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - # t984 = prims.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - t985 = prims.convert_element_type(t984, dtypes.bfloat16) # t985: "cuda:0 bf16[1, 512, 11008]" - t986 = prims.convert_element_type(t985, dtypes.float32) # t986: "cuda:0 f32[1, 512, 11008]" - t987 = prims.convert_element_type(t975, dtypes.float32) # t987: "cuda:0 f32[1, 512, 11008]" - t988 = ltorch.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - # t988 = prims.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - t989 = prims.convert_element_type(t988, dtypes.bfloat16) # t989: "cuda:0 bf16[1, 512, 11008]" - t990 = prims.linear(t989, t_transformer_h_6_mlp_proj_weight, None) # t990: "cuda:0 bf16[1, 512, 4096]" - t991 = prims.convert_element_type(t990, dtypes.float32) # t991: "cuda:0 f32[1, 512, 4096]" - t992 = prims.convert_element_type(t955, dtypes.float32) # t992: "cuda:0 f32[1, 512, 4096]" - t993 = ltorch.add(t991, t992, alpha=None) # t993: "cuda:0 f32[1, 512, 4096]" - # t993 = prims.add(t991, t992) # t993: "cuda:0 f32[1, 512, 4096]" - t994 = prims.convert_element_type(t993, dtypes.bfloat16) # t994: "cuda:0 bf16[1, 512, 4096]" - t995 = prims.convert_element_type(t994, dtypes.float32) # t995: "cuda:0 f32[1, 512, 4096]" - t996 = ltorch.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - # t996 = prims.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - t998 = prims.sum(t996, (2,)) # t998: "cuda:0 f32[1, 512]" - t999 = prims.broadcast_in_dim(t998, [1, 512, 1], [0, 1]) # t999: "cuda:0 f32[1, 512, 1]" - t1001 = ltorch.true_divide(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - # t1001 = prims.div(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - t1003 = ltorch.add(t1001, 1e-05, alpha=None) # t1003: "cuda:0 f32[1, 512, 1]" - # t1003 = prims.add(t1001, 1e-05) # t1003: "cuda:0 f32[1, 512, 1]" - t1004 = prims.rsqrt(t1003) # t1004: "cuda:0 f32[1, 512, 1]" - t1005 = prims.broadcast_in_dim(t1004, (1, 512, 4096), (0, 1, 2)) # t1005: "cuda:0 f32[1, 512, 4096]" - t1006 = ltorch.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - # t1006 = prims.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - t1007 = prims.convert_element_type(t1006, dtypes.bfloat16) # t1007: "cuda:0 bf16[1, 512, 4096]" - t1008 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_weight, (1, 512, 4096), (2,)) # t1008: "cuda:0 bf16[1, 512, 4096]" - t1009 = prims.convert_element_type(t1007, dtypes.float32) # t1009: "cuda:0 f32[1, 512, 4096]" - t1010 = prims.convert_element_type(t1008, dtypes.float32) # t1010: "cuda:0 f32[1, 512, 4096]" - t1011 = ltorch.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - # t1011 = prims.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - t1012 = prims.convert_element_type(t1011, dtypes.bfloat16) # t1012: "cuda:0 bf16[1, 512, 4096]" - t1013 = prims.linear(t1012, t_transformer_h_7_attn_attn_weight, None) # t1013: "cuda:0 bf16[1, 512, 12288]" - t1019 = prims.reshape(t1013, (1, 512, 32, 3, 128)) # t1019: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1025 = prims.transpose(t1019, (0, 2, 3, 1, 4)) # t1025: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1026, t1027, t1028) = ltorch.split(t1025, (1, 1, 1), 2) - # t1026 = prims.slice_prim(t1025, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1026: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1027 = prims.slice_prim(t1025, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1027: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1028 = prims.slice_prim(t1025, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1028: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1034 = prims.reshape(t1026, (1, 32, 512, 128)) # t1034: "cuda:0 bf16[1, 32, 512, 128]" - t1040 = prims.reshape(t1027, (1, 32, 512, 128)) # t1040: "cuda:0 bf16[1, 32, 512, 128]" - t1046 = prims.reshape(t1028, (1, 32, 512, 128)) # t1046: "cuda:0 bf16[1, 32, 512, 128]" - t1047 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1047: "cuda:0 bf16[1, 32, 512, 128]" - t1048 = prims.slice_prim(t1047, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1048: "cuda:0 bf16[1, 32, 512, 64]" - t1049 = prims.slice_prim(t1047, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1049: "cuda:0 bf16[1, 32, 512, 64]" - t1050 = prims.convert_element_type(t1049, dtypes.float32) # t1050: "cuda:0 f32[1, 32, 512, 64]" - t1051 = prims.neg(t1050) # t1051: "cuda:0 f32[1, 32, 512, 64]" - t1052 = prims.convert_element_type(t1051, dtypes.bfloat16) # t1052: "cuda:0 bf16[1, 32, 512, 64]" - t1054 = prims.cat((t1052, t1048), -1) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - t1055 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1055: "cuda:0 f32[1, 32, 512, 128]" - t1056 = prims.convert_element_type(t1047, dtypes.float32) # t1056: "cuda:0 f32[1, 32, 512, 128]" - t1057 = ltorch.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - # t1057 = prims.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - t1058 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1058: "cuda:0 f32[1, 32, 512, 128]" - t1059 = prims.convert_element_type(t1054, dtypes.float32) # t1059: "cuda:0 f32[1, 32, 512, 128]" - t1060 = ltorch.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - # t1060 = prims.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - t1061 = ltorch.add(t1057, t1060, alpha=None) # t1061: "cuda:0 f32[1, 32, 512, 128]" - # t1061 = prims.add(t1057, t1060) # t1061: "cuda:0 f32[1, 32, 512, 128]" - t1062 = prims.convert_element_type(t1061, dtypes.bfloat16) # t1062: "cuda:0 bf16[1, 32, 512, 128]" - t1063 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1063: "cuda:0 bf16[1, 32, 512, 128]" - t1064 = prims.slice_prim(t1063, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1064: "cuda:0 bf16[1, 32, 512, 64]" - t1065 = prims.slice_prim(t1063, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1065: "cuda:0 bf16[1, 32, 512, 64]" - t1066 = prims.convert_element_type(t1065, dtypes.float32) # t1066: "cuda:0 f32[1, 32, 512, 64]" - t1067 = prims.neg(t1066) # t1067: "cuda:0 f32[1, 32, 512, 64]" - t1068 = prims.convert_element_type(t1067, dtypes.bfloat16) # t1068: "cuda:0 bf16[1, 32, 512, 64]" - t1070 = prims.cat((t1068, t1064), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - t1071 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1071: "cuda:0 f32[1, 32, 512, 128]" - t1072 = prims.convert_element_type(t1063, dtypes.float32) # t1072: "cuda:0 f32[1, 32, 512, 128]" - t1073 = ltorch.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - # t1073 = prims.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - t1074 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1074: "cuda:0 f32[1, 32, 512, 128]" - t1075 = prims.convert_element_type(t1070, dtypes.float32) # t1075: "cuda:0 f32[1, 32, 512, 128]" - t1076 = ltorch.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - # t1076 = prims.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - t1077 = ltorch.add(t1073, t1076, alpha=None) # t1077: "cuda:0 f32[1, 32, 512, 128]" - # t1077 = prims.add(t1073, t1076) # t1077: "cuda:0 f32[1, 32, 512, 128]" - t1078 = prims.convert_element_type(t1077, dtypes.bfloat16) # t1078: "cuda:0 bf16[1, 32, 512, 128]" - t1079 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1079: "cuda:0 bf16[1, 32, 512, 0]" - t1081 = prims.cat((t1062, t1079), -1) # t1081: "cuda:0 bf16[1, 32, 512, 128]" - t1082 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1082: "cuda:0 bf16[1, 32, 512, 0]" - t1084 = prims.cat((t1078, t1082), -1) # t1084: "cuda:0 bf16[1, 32, 512, 128]" - (t1085, t1086, t1087, t1088) = cudnn_sdpa_fwd(t1081, t1084, t1046, None, 0.0, True, scale=0.08838834764831843) - t1091 = prims.transpose(t1085, (0, 2, 1, 3)) # t1091: "cuda:0 bf16[1, 512, 32, 128]" - t1095 = prims.reshape(t1091, (1, 512, 4096)) # t1095: "cuda:0 bf16[1, 512, 4096]" - t1096 = prims.linear(t1095, t_transformer_h_7_attn_proj_weight, None) # t1096: "cuda:0 bf16[1, 512, 4096]" - t1097 = prims.convert_element_type(t1096, dtypes.float32) # t1097: "cuda:0 f32[1, 512, 4096]" - t1098 = prims.convert_element_type(t994, dtypes.float32) # t1098: "cuda:0 f32[1, 512, 4096]" - t1099 = ltorch.add(t1097, t1098, alpha=None) # t1099: "cuda:0 f32[1, 512, 4096]" - # t1099 = prims.add(t1097, t1098) # t1099: "cuda:0 f32[1, 512, 4096]" - t1100 = prims.convert_element_type(t1099, dtypes.bfloat16) # t1100: "cuda:0 bf16[1, 512, 4096]" - t1101 = prims.convert_element_type(t1100, dtypes.float32) # t1101: "cuda:0 f32[1, 512, 4096]" - t1102 = ltorch.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - # t1102 = prims.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - t1104 = prims.sum(t1102, (2,)) # t1104: "cuda:0 f32[1, 512]" - t1105 = prims.broadcast_in_dim(t1104, [1, 512, 1], [0, 1]) # t1105: "cuda:0 f32[1, 512, 1]" - t1107 = ltorch.true_divide(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - # t1107 = prims.div(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - t1109 = ltorch.add(t1107, 1e-05, alpha=None) # t1109: "cuda:0 f32[1, 512, 1]" - # t1109 = prims.add(t1107, 1e-05) # t1109: "cuda:0 f32[1, 512, 1]" - t1110 = prims.rsqrt(t1109) # t1110: "cuda:0 f32[1, 512, 1]" - t1111 = prims.broadcast_in_dim(t1110, (1, 512, 4096), (0, 1, 2)) # t1111: "cuda:0 f32[1, 512, 4096]" - t1112 = ltorch.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - # t1112 = prims.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - t1113 = prims.convert_element_type(t1112, dtypes.bfloat16) # t1113: "cuda:0 bf16[1, 512, 4096]" - t1114 = prims.broadcast_in_dim(t_transformer_h_7_norm_2_weight, (1, 512, 4096), (2,)) # t1114: "cuda:0 bf16[1, 512, 4096]" - t1115 = prims.convert_element_type(t1113, dtypes.float32) # t1115: "cuda:0 f32[1, 512, 4096]" - t1116 = prims.convert_element_type(t1114, dtypes.float32) # t1116: "cuda:0 f32[1, 512, 4096]" - t1117 = ltorch.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - # t1117 = prims.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - t1118 = prims.convert_element_type(t1117, dtypes.bfloat16) # t1118: "cuda:0 bf16[1, 512, 4096]" - t1119 = prims.linear(t1118, t_transformer_h_7_mlp_fc_1_weight, None) # t1119: "cuda:0 bf16[1, 512, 11008]" - t1120 = prims.linear(t1118, t_transformer_h_7_mlp_fc_2_weight, None) # t1120: "cuda:0 bf16[1, 512, 11008]" - t1121 = prims.convert_element_type(t1119, dtypes.float32) # t1121: "cuda:0 f32[1, 512, 11008]" - t1122 = prims.neg(t1121) # t1122: "cuda:0 f32[1, 512, 11008]" - t1123 = prims.exp(t1122) # t1123: "cuda:0 f32[1, 512, 11008]" - t1124 = ltorch.add(1.0, t1123, alpha=None) # t1124: "cuda:0 f32[1, 512, 11008]" - # t1124 = prims.add(1.0, t1123) # t1124: "cuda:0 f32[1, 512, 11008]" - t1125 = prims.reciprocal(t1124) # t1125: "cuda:0 f32[1, 512, 11008]" - t1126 = prims.convert_element_type(t1125, dtypes.bfloat16) # t1126: "cuda:0 bf16[1, 512, 11008]" - t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: "cuda:0 f32[1, 512, 11008]" - t1128 = prims.convert_element_type(t1126, dtypes.float32) # t1128: "cuda:0 f32[1, 512, 11008]" - t1129 = ltorch.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - # t1129 = prims.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - t1130 = prims.convert_element_type(t1129, dtypes.bfloat16) # t1130: "cuda:0 bf16[1, 512, 11008]" - t1131 = prims.convert_element_type(t1130, dtypes.float32) # t1131: "cuda:0 f32[1, 512, 11008]" - t1132 = prims.convert_element_type(t1120, dtypes.float32) # t1132: "cuda:0 f32[1, 512, 11008]" - t1133 = ltorch.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - # t1133 = prims.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - t1134 = prims.convert_element_type(t1133, dtypes.bfloat16) # t1134: "cuda:0 bf16[1, 512, 11008]" - t1135 = prims.linear(t1134, t_transformer_h_7_mlp_proj_weight, None) # t1135: "cuda:0 bf16[1, 512, 4096]" - t1136 = prims.convert_element_type(t1135, dtypes.float32) # t1136: "cuda:0 f32[1, 512, 4096]" - t1137 = prims.convert_element_type(t1100, dtypes.float32) # t1137: "cuda:0 f32[1, 512, 4096]" - t1138 = ltorch.add(t1136, t1137, alpha=None) # t1138: "cuda:0 f32[1, 512, 4096]" - # t1138 = prims.add(t1136, t1137) # t1138: "cuda:0 f32[1, 512, 4096]" - t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: "cuda:0 bf16[1, 512, 4096]" - t1140 = prims.convert_element_type(t1139, dtypes.float32) # t1140: "cuda:0 f32[1, 512, 4096]" - t1141 = ltorch.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - # t1141 = prims.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - t1143 = prims.sum(t1141, (2,)) # t1143: "cuda:0 f32[1, 512]" - t1144 = prims.broadcast_in_dim(t1143, [1, 512, 1], [0, 1]) # t1144: "cuda:0 f32[1, 512, 1]" - t1146 = ltorch.true_divide(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - # t1146 = prims.div(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - t1148 = ltorch.add(t1146, 1e-05, alpha=None) # t1148: "cuda:0 f32[1, 512, 1]" - # t1148 = prims.add(t1146, 1e-05) # t1148: "cuda:0 f32[1, 512, 1]" - t1149 = prims.rsqrt(t1148) # t1149: "cuda:0 f32[1, 512, 1]" - t1150 = prims.broadcast_in_dim(t1149, (1, 512, 4096), (0, 1, 2)) # t1150: "cuda:0 f32[1, 512, 4096]" - t1151 = ltorch.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - # t1151 = prims.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - t1152 = prims.convert_element_type(t1151, dtypes.bfloat16) # t1152: "cuda:0 bf16[1, 512, 4096]" - t1153 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_weight, (1, 512, 4096), (2,)) # t1153: "cuda:0 bf16[1, 512, 4096]" - t1154 = prims.convert_element_type(t1152, dtypes.float32) # t1154: "cuda:0 f32[1, 512, 4096]" - t1155 = prims.convert_element_type(t1153, dtypes.float32) # t1155: "cuda:0 f32[1, 512, 4096]" - t1156 = ltorch.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - # t1156 = prims.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - t1157 = prims.convert_element_type(t1156, dtypes.bfloat16) # t1157: "cuda:0 bf16[1, 512, 4096]" - t1158 = prims.linear(t1157, t_transformer_h_8_attn_attn_weight, None) # t1158: "cuda:0 bf16[1, 512, 12288]" - t1164 = prims.reshape(t1158, (1, 512, 32, 3, 128)) # t1164: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1170 = prims.transpose(t1164, (0, 2, 3, 1, 4)) # t1170: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1171, t1172, t1173) = ltorch.split(t1170, (1, 1, 1), 2) - # t1171 = prims.slice_prim(t1170, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1171: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1172 = prims.slice_prim(t1170, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1172: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1173 = prims.slice_prim(t1170, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1173: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1179 = prims.reshape(t1171, (1, 32, 512, 128)) # t1179: "cuda:0 bf16[1, 32, 512, 128]" - t1185 = prims.reshape(t1172, (1, 32, 512, 128)) # t1185: "cuda:0 bf16[1, 32, 512, 128]" - t1191 = prims.reshape(t1173, (1, 32, 512, 128)) # t1191: "cuda:0 bf16[1, 32, 512, 128]" - t1192 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1192: "cuda:0 bf16[1, 32, 512, 128]" - t1193 = prims.slice_prim(t1192, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1193: "cuda:0 bf16[1, 32, 512, 64]" - t1194 = prims.slice_prim(t1192, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1194: "cuda:0 bf16[1, 32, 512, 64]" - t1195 = prims.convert_element_type(t1194, dtypes.float32) # t1195: "cuda:0 f32[1, 32, 512, 64]" - t1196 = prims.neg(t1195) # t1196: "cuda:0 f32[1, 32, 512, 64]" - t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: "cuda:0 bf16[1, 32, 512, 64]" - t1199 = prims.cat((t1197, t1193), -1) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - t1200 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1200: "cuda:0 f32[1, 32, 512, 128]" - t1201 = prims.convert_element_type(t1192, dtypes.float32) # t1201: "cuda:0 f32[1, 32, 512, 128]" - t1202 = ltorch.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - # t1202 = prims.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - t1203 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1203: "cuda:0 f32[1, 32, 512, 128]" - t1204 = prims.convert_element_type(t1199, dtypes.float32) # t1204: "cuda:0 f32[1, 32, 512, 128]" - t1205 = ltorch.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - # t1205 = prims.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - t1206 = ltorch.add(t1202, t1205, alpha=None) # t1206: "cuda:0 f32[1, 32, 512, 128]" - # t1206 = prims.add(t1202, t1205) # t1206: "cuda:0 f32[1, 32, 512, 128]" - t1207 = prims.convert_element_type(t1206, dtypes.bfloat16) # t1207: "cuda:0 bf16[1, 32, 512, 128]" - t1208 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - t1209 = prims.slice_prim(t1208, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1209: "cuda:0 bf16[1, 32, 512, 64]" - t1210 = prims.slice_prim(t1208, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1210: "cuda:0 bf16[1, 32, 512, 64]" - t1211 = prims.convert_element_type(t1210, dtypes.float32) # t1211: "cuda:0 f32[1, 32, 512, 64]" - t1212 = prims.neg(t1211) # t1212: "cuda:0 f32[1, 32, 512, 64]" - t1213 = prims.convert_element_type(t1212, dtypes.bfloat16) # t1213: "cuda:0 bf16[1, 32, 512, 64]" - t1215 = prims.cat((t1213, t1209), -1) # t1215: "cuda:0 bf16[1, 32, 512, 128]" - t1216 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1216: "cuda:0 f32[1, 32, 512, 128]" - t1217 = prims.convert_element_type(t1208, dtypes.float32) # t1217: "cuda:0 f32[1, 32, 512, 128]" - t1218 = ltorch.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - # t1218 = prims.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - t1219 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1219: "cuda:0 f32[1, 32, 512, 128]" - t1220 = prims.convert_element_type(t1215, dtypes.float32) # t1220: "cuda:0 f32[1, 32, 512, 128]" - t1221 = ltorch.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - # t1221 = prims.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - t1222 = ltorch.add(t1218, t1221, alpha=None) # t1222: "cuda:0 f32[1, 32, 512, 128]" - # t1222 = prims.add(t1218, t1221) # t1222: "cuda:0 f32[1, 32, 512, 128]" - t1223 = prims.convert_element_type(t1222, dtypes.bfloat16) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - t1224 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1224: "cuda:0 bf16[1, 32, 512, 0]" - t1226 = prims.cat((t1207, t1224), -1) # t1226: "cuda:0 bf16[1, 32, 512, 128]" - t1227 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1227: "cuda:0 bf16[1, 32, 512, 0]" - t1229 = prims.cat((t1223, t1227), -1) # t1229: "cuda:0 bf16[1, 32, 512, 128]" - (t1230, t1231, t1232, t1233) = cudnn_sdpa_fwd(t1226, t1229, t1191, None, 0.0, True, scale=0.08838834764831843) - t1236 = prims.transpose(t1230, (0, 2, 1, 3)) # t1236: "cuda:0 bf16[1, 512, 32, 128]" - t1240 = prims.reshape(t1236, (1, 512, 4096)) # t1240: "cuda:0 bf16[1, 512, 4096]" - t1241 = prims.linear(t1240, t_transformer_h_8_attn_proj_weight, None) # t1241: "cuda:0 bf16[1, 512, 4096]" - t1242 = prims.convert_element_type(t1241, dtypes.float32) # t1242: "cuda:0 f32[1, 512, 4096]" - t1243 = prims.convert_element_type(t1139, dtypes.float32) # t1243: "cuda:0 f32[1, 512, 4096]" - t1244 = ltorch.add(t1242, t1243, alpha=None) # t1244: "cuda:0 f32[1, 512, 4096]" - # t1244 = prims.add(t1242, t1243) # t1244: "cuda:0 f32[1, 512, 4096]" - t1245 = prims.convert_element_type(t1244, dtypes.bfloat16) # t1245: "cuda:0 bf16[1, 512, 4096]" - t1246 = prims.convert_element_type(t1245, dtypes.float32) # t1246: "cuda:0 f32[1, 512, 4096]" - t1247 = ltorch.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - # t1247 = prims.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - t1249 = prims.sum(t1247, (2,)) # t1249: "cuda:0 f32[1, 512]" - t1250 = prims.broadcast_in_dim(t1249, [1, 512, 1], [0, 1]) # t1250: "cuda:0 f32[1, 512, 1]" - t1252 = ltorch.true_divide(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - # t1252 = prims.div(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - t1254 = ltorch.add(t1252, 1e-05, alpha=None) # t1254: "cuda:0 f32[1, 512, 1]" - # t1254 = prims.add(t1252, 1e-05) # t1254: "cuda:0 f32[1, 512, 1]" - t1255 = prims.rsqrt(t1254) # t1255: "cuda:0 f32[1, 512, 1]" - t1256 = prims.broadcast_in_dim(t1255, (1, 512, 4096), (0, 1, 2)) # t1256: "cuda:0 f32[1, 512, 4096]" - t1257 = ltorch.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - # t1257 = prims.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - t1258 = prims.convert_element_type(t1257, dtypes.bfloat16) # t1258: "cuda:0 bf16[1, 512, 4096]" - t1259 = prims.broadcast_in_dim(t_transformer_h_8_norm_2_weight, (1, 512, 4096), (2,)) # t1259: "cuda:0 bf16[1, 512, 4096]" - t1260 = prims.convert_element_type(t1258, dtypes.float32) # t1260: "cuda:0 f32[1, 512, 4096]" - t1261 = prims.convert_element_type(t1259, dtypes.float32) # t1261: "cuda:0 f32[1, 512, 4096]" - t1262 = ltorch.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - # t1262 = prims.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - t1263 = prims.convert_element_type(t1262, dtypes.bfloat16) # t1263: "cuda:0 bf16[1, 512, 4096]" - t1264 = prims.linear(t1263, t_transformer_h_8_mlp_fc_1_weight, None) # t1264: "cuda:0 bf16[1, 512, 11008]" - t1265 = prims.linear(t1263, t_transformer_h_8_mlp_fc_2_weight, None) # t1265: "cuda:0 bf16[1, 512, 11008]" - t1266 = prims.convert_element_type(t1264, dtypes.float32) # t1266: "cuda:0 f32[1, 512, 11008]" - t1267 = prims.neg(t1266) # t1267: "cuda:0 f32[1, 512, 11008]" - t1268 = prims.exp(t1267) # t1268: "cuda:0 f32[1, 512, 11008]" - t1269 = ltorch.add(1.0, t1268, alpha=None) # t1269: "cuda:0 f32[1, 512, 11008]" - # t1269 = prims.add(1.0, t1268) # t1269: "cuda:0 f32[1, 512, 11008]" - t1270 = prims.reciprocal(t1269) # t1270: "cuda:0 f32[1, 512, 11008]" - t1271 = prims.convert_element_type(t1270, dtypes.bfloat16) # t1271: "cuda:0 bf16[1, 512, 11008]" - t1272 = prims.convert_element_type(t1264, dtypes.float32) # t1272: "cuda:0 f32[1, 512, 11008]" - t1273 = prims.convert_element_type(t1271, dtypes.float32) # t1273: "cuda:0 f32[1, 512, 11008]" - t1274 = ltorch.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - # t1274 = prims.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - t1275 = prims.convert_element_type(t1274, dtypes.bfloat16) # t1275: "cuda:0 bf16[1, 512, 11008]" - t1276 = prims.convert_element_type(t1275, dtypes.float32) # t1276: "cuda:0 f32[1, 512, 11008]" - t1277 = prims.convert_element_type(t1265, dtypes.float32) # t1277: "cuda:0 f32[1, 512, 11008]" - t1278 = ltorch.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - # t1278 = prims.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - t1279 = prims.convert_element_type(t1278, dtypes.bfloat16) # t1279: "cuda:0 bf16[1, 512, 11008]" - t1280 = prims.linear(t1279, t_transformer_h_8_mlp_proj_weight, None) # t1280: "cuda:0 bf16[1, 512, 4096]" - t1281 = prims.convert_element_type(t1280, dtypes.float32) # t1281: "cuda:0 f32[1, 512, 4096]" - t1282 = prims.convert_element_type(t1245, dtypes.float32) # t1282: "cuda:0 f32[1, 512, 4096]" - t1283 = ltorch.add(t1281, t1282, alpha=None) # t1283: "cuda:0 f32[1, 512, 4096]" - # t1283 = prims.add(t1281, t1282) # t1283: "cuda:0 f32[1, 512, 4096]" - t1284 = prims.convert_element_type(t1283, dtypes.bfloat16) # t1284: "cuda:0 bf16[1, 512, 4096]" - t1285 = prims.convert_element_type(t1284, dtypes.float32) # t1285: "cuda:0 f32[1, 512, 4096]" - t1286 = ltorch.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - # t1286 = prims.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - t1288 = prims.sum(t1286, (2,)) # t1288: "cuda:0 f32[1, 512]" - t1289 = prims.broadcast_in_dim(t1288, [1, 512, 1], [0, 1]) # t1289: "cuda:0 f32[1, 512, 1]" - t1291 = ltorch.true_divide(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - # t1291 = prims.div(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - t1293 = ltorch.add(t1291, 1e-05, alpha=None) # t1293: "cuda:0 f32[1, 512, 1]" - # t1293 = prims.add(t1291, 1e-05) # t1293: "cuda:0 f32[1, 512, 1]" - t1294 = prims.rsqrt(t1293) # t1294: "cuda:0 f32[1, 512, 1]" - t1295 = prims.broadcast_in_dim(t1294, (1, 512, 4096), (0, 1, 2)) # t1295: "cuda:0 f32[1, 512, 4096]" - t1296 = ltorch.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - # t1296 = prims.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - t1297 = prims.convert_element_type(t1296, dtypes.bfloat16) # t1297: "cuda:0 bf16[1, 512, 4096]" - t1298 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_weight, (1, 512, 4096), (2,)) # t1298: "cuda:0 bf16[1, 512, 4096]" - t1299 = prims.convert_element_type(t1297, dtypes.float32) # t1299: "cuda:0 f32[1, 512, 4096]" - t1300 = prims.convert_element_type(t1298, dtypes.float32) # t1300: "cuda:0 f32[1, 512, 4096]" - t1301 = ltorch.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - # t1301 = prims.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - t1302 = prims.convert_element_type(t1301, dtypes.bfloat16) # t1302: "cuda:0 bf16[1, 512, 4096]" - t1303 = prims.linear(t1302, t_transformer_h_9_attn_attn_weight, None) # t1303: "cuda:0 bf16[1, 512, 12288]" - t1309 = prims.reshape(t1303, (1, 512, 32, 3, 128)) # t1309: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1315 = prims.transpose(t1309, (0, 2, 3, 1, 4)) # t1315: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1316, t1317, t1318) = ltorch.split(t1315, (1, 1, 1), 2) - # t1316 = prims.slice_prim(t1315, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1316: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1317 = prims.slice_prim(t1315, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1317: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1318 = prims.slice_prim(t1315, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1318: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1324 = prims.reshape(t1316, (1, 32, 512, 128)) # t1324: "cuda:0 bf16[1, 32, 512, 128]" - t1330 = prims.reshape(t1317, (1, 32, 512, 128)) # t1330: "cuda:0 bf16[1, 32, 512, 128]" - t1336 = prims.reshape(t1318, (1, 32, 512, 128)) # t1336: "cuda:0 bf16[1, 32, 512, 128]" - t1337 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: "cuda:0 bf16[1, 32, 512, 128]" - t1338 = prims.slice_prim(t1337, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1338: "cuda:0 bf16[1, 32, 512, 64]" - t1339 = prims.slice_prim(t1337, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1339: "cuda:0 bf16[1, 32, 512, 64]" - t1340 = prims.convert_element_type(t1339, dtypes.float32) # t1340: "cuda:0 f32[1, 32, 512, 64]" - t1341 = prims.neg(t1340) # t1341: "cuda:0 f32[1, 32, 512, 64]" - t1342 = prims.convert_element_type(t1341, dtypes.bfloat16) # t1342: "cuda:0 bf16[1, 32, 512, 64]" - t1344 = prims.cat((t1342, t1338), -1) # t1344: "cuda:0 bf16[1, 32, 512, 128]" - t1345 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1345: "cuda:0 f32[1, 32, 512, 128]" - t1346 = prims.convert_element_type(t1337, dtypes.float32) # t1346: "cuda:0 f32[1, 32, 512, 128]" - t1347 = ltorch.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - # t1347 = prims.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - t1348 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1348: "cuda:0 f32[1, 32, 512, 128]" - t1349 = prims.convert_element_type(t1344, dtypes.float32) # t1349: "cuda:0 f32[1, 32, 512, 128]" - t1350 = ltorch.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - # t1350 = prims.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - t1351 = ltorch.add(t1347, t1350, alpha=None) # t1351: "cuda:0 f32[1, 32, 512, 128]" - # t1351 = prims.add(t1347, t1350) # t1351: "cuda:0 f32[1, 32, 512, 128]" - t1352 = prims.convert_element_type(t1351, dtypes.bfloat16) # t1352: "cuda:0 bf16[1, 32, 512, 128]" - t1353 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1353: "cuda:0 bf16[1, 32, 512, 128]" - t1354 = prims.slice_prim(t1353, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1354: "cuda:0 bf16[1, 32, 512, 64]" - t1355 = prims.slice_prim(t1353, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1355: "cuda:0 bf16[1, 32, 512, 64]" - t1356 = prims.convert_element_type(t1355, dtypes.float32) # t1356: "cuda:0 f32[1, 32, 512, 64]" - t1357 = prims.neg(t1356) # t1357: "cuda:0 f32[1, 32, 512, 64]" - t1358 = prims.convert_element_type(t1357, dtypes.bfloat16) # t1358: "cuda:0 bf16[1, 32, 512, 64]" - t1360 = prims.cat((t1358, t1354), -1) # t1360: "cuda:0 bf16[1, 32, 512, 128]" - t1361 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1361: "cuda:0 f32[1, 32, 512, 128]" - t1362 = prims.convert_element_type(t1353, dtypes.float32) # t1362: "cuda:0 f32[1, 32, 512, 128]" - t1363 = ltorch.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - # t1363 = prims.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - t1364 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1364: "cuda:0 f32[1, 32, 512, 128]" - t1365 = prims.convert_element_type(t1360, dtypes.float32) # t1365: "cuda:0 f32[1, 32, 512, 128]" - t1366 = ltorch.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - # t1366 = prims.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - t1367 = ltorch.add(t1363, t1366, alpha=None) # t1367: "cuda:0 f32[1, 32, 512, 128]" - # t1367 = prims.add(t1363, t1366) # t1367: "cuda:0 f32[1, 32, 512, 128]" - t1368 = prims.convert_element_type(t1367, dtypes.bfloat16) # t1368: "cuda:0 bf16[1, 32, 512, 128]" - t1369 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1369: "cuda:0 bf16[1, 32, 512, 0]" - t1371 = prims.cat((t1352, t1369), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - t1372 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1372: "cuda:0 bf16[1, 32, 512, 0]" - t1374 = prims.cat((t1368, t1372), -1) # t1374: "cuda:0 bf16[1, 32, 512, 128]" - (t1375, t1376, t1377, t1378) = cudnn_sdpa_fwd(t1371, t1374, t1336, None, 0.0, True, scale=0.08838834764831843) - t1381 = prims.transpose(t1375, (0, 2, 1, 3)) # t1381: "cuda:0 bf16[1, 512, 32, 128]" - t1385 = prims.reshape(t1381, (1, 512, 4096)) # t1385: "cuda:0 bf16[1, 512, 4096]" - t1386 = prims.linear(t1385, t_transformer_h_9_attn_proj_weight, None) # t1386: "cuda:0 bf16[1, 512, 4096]" - t1387 = prims.convert_element_type(t1386, dtypes.float32) # t1387: "cuda:0 f32[1, 512, 4096]" - t1388 = prims.convert_element_type(t1284, dtypes.float32) # t1388: "cuda:0 f32[1, 512, 4096]" - t1389 = ltorch.add(t1387, t1388, alpha=None) # t1389: "cuda:0 f32[1, 512, 4096]" - # t1389 = prims.add(t1387, t1388) # t1389: "cuda:0 f32[1, 512, 4096]" - t1390 = prims.convert_element_type(t1389, dtypes.bfloat16) # t1390: "cuda:0 bf16[1, 512, 4096]" - t1391 = prims.convert_element_type(t1390, dtypes.float32) # t1391: "cuda:0 f32[1, 512, 4096]" - t1392 = ltorch.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - # t1392 = prims.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - t1394 = prims.sum(t1392, (2,)) # t1394: "cuda:0 f32[1, 512]" - t1395 = prims.broadcast_in_dim(t1394, [1, 512, 1], [0, 1]) # t1395: "cuda:0 f32[1, 512, 1]" - t1397 = ltorch.true_divide(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - # t1397 = prims.div(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - t1399 = ltorch.add(t1397, 1e-05, alpha=None) # t1399: "cuda:0 f32[1, 512, 1]" - # t1399 = prims.add(t1397, 1e-05) # t1399: "cuda:0 f32[1, 512, 1]" - t1400 = prims.rsqrt(t1399) # t1400: "cuda:0 f32[1, 512, 1]" - t1401 = prims.broadcast_in_dim(t1400, (1, 512, 4096), (0, 1, 2)) # t1401: "cuda:0 f32[1, 512, 4096]" - t1402 = ltorch.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - # t1402 = prims.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - t1403 = prims.convert_element_type(t1402, dtypes.bfloat16) # t1403: "cuda:0 bf16[1, 512, 4096]" - t1404 = prims.broadcast_in_dim(t_transformer_h_9_norm_2_weight, (1, 512, 4096), (2,)) # t1404: "cuda:0 bf16[1, 512, 4096]" - t1405 = prims.convert_element_type(t1403, dtypes.float32) # t1405: "cuda:0 f32[1, 512, 4096]" - t1406 = prims.convert_element_type(t1404, dtypes.float32) # t1406: "cuda:0 f32[1, 512, 4096]" - t1407 = ltorch.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - # t1407 = prims.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - t1408 = prims.convert_element_type(t1407, dtypes.bfloat16) # t1408: "cuda:0 bf16[1, 512, 4096]" - t1409 = prims.linear(t1408, t_transformer_h_9_mlp_fc_1_weight, None) # t1409: "cuda:0 bf16[1, 512, 11008]" - t1410 = prims.linear(t1408, t_transformer_h_9_mlp_fc_2_weight, None) # t1410: "cuda:0 bf16[1, 512, 11008]" - t1411 = prims.convert_element_type(t1409, dtypes.float32) # t1411: "cuda:0 f32[1, 512, 11008]" - t1412 = prims.neg(t1411) # t1412: "cuda:0 f32[1, 512, 11008]" - t1413 = prims.exp(t1412) # t1413: "cuda:0 f32[1, 512, 11008]" - t1414 = ltorch.add(1.0, t1413, alpha=None) # t1414: "cuda:0 f32[1, 512, 11008]" - # t1414 = prims.add(1.0, t1413) # t1414: "cuda:0 f32[1, 512, 11008]" - t1415 = prims.reciprocal(t1414) # t1415: "cuda:0 f32[1, 512, 11008]" - t1416 = prims.convert_element_type(t1415, dtypes.bfloat16) # t1416: "cuda:0 bf16[1, 512, 11008]" - t1417 = prims.convert_element_type(t1409, dtypes.float32) # t1417: "cuda:0 f32[1, 512, 11008]" - t1418 = prims.convert_element_type(t1416, dtypes.float32) # t1418: "cuda:0 f32[1, 512, 11008]" - t1419 = ltorch.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - # t1419 = prims.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - t1420 = prims.convert_element_type(t1419, dtypes.bfloat16) # t1420: "cuda:0 bf16[1, 512, 11008]" - t1421 = prims.convert_element_type(t1420, dtypes.float32) # t1421: "cuda:0 f32[1, 512, 11008]" - t1422 = prims.convert_element_type(t1410, dtypes.float32) # t1422: "cuda:0 f32[1, 512, 11008]" - t1423 = ltorch.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - # t1423 = prims.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - t1424 = prims.convert_element_type(t1423, dtypes.bfloat16) # t1424: "cuda:0 bf16[1, 512, 11008]" - t1425 = prims.linear(t1424, t_transformer_h_9_mlp_proj_weight, None) # t1425: "cuda:0 bf16[1, 512, 4096]" - t1426 = prims.convert_element_type(t1425, dtypes.float32) # t1426: "cuda:0 f32[1, 512, 4096]" - t1427 = prims.convert_element_type(t1390, dtypes.float32) # t1427: "cuda:0 f32[1, 512, 4096]" - t1428 = ltorch.add(t1426, t1427, alpha=None) # t1428: "cuda:0 f32[1, 512, 4096]" - # t1428 = prims.add(t1426, t1427) # t1428: "cuda:0 f32[1, 512, 4096]" - t1429 = prims.convert_element_type(t1428, dtypes.bfloat16) # t1429: "cuda:0 bf16[1, 512, 4096]" - t1430 = prims.convert_element_type(t1429, dtypes.float32) # t1430: "cuda:0 f32[1, 512, 4096]" - t1431 = ltorch.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - # t1431 = prims.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - t1433 = prims.sum(t1431, (2,)) # t1433: "cuda:0 f32[1, 512]" - t1434 = prims.broadcast_in_dim(t1433, [1, 512, 1], [0, 1]) # t1434: "cuda:0 f32[1, 512, 1]" - t1436 = ltorch.true_divide(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - # t1436 = prims.div(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - t1438 = ltorch.add(t1436, 1e-05, alpha=None) # t1438: "cuda:0 f32[1, 512, 1]" - # t1438 = prims.add(t1436, 1e-05) # t1438: "cuda:0 f32[1, 512, 1]" - t1439 = prims.rsqrt(t1438) # t1439: "cuda:0 f32[1, 512, 1]" - t1440 = prims.broadcast_in_dim(t1439, (1, 512, 4096), (0, 1, 2)) # t1440: "cuda:0 f32[1, 512, 4096]" - t1441 = ltorch.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - # t1441 = prims.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - t1442 = prims.convert_element_type(t1441, dtypes.bfloat16) # t1442: "cuda:0 bf16[1, 512, 4096]" - t1443 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_weight, (1, 512, 4096), (2,)) # t1443: "cuda:0 bf16[1, 512, 4096]" - t1444 = prims.convert_element_type(t1442, dtypes.float32) # t1444: "cuda:0 f32[1, 512, 4096]" - t1445 = prims.convert_element_type(t1443, dtypes.float32) # t1445: "cuda:0 f32[1, 512, 4096]" - t1446 = ltorch.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - # t1446 = prims.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - t1447 = prims.convert_element_type(t1446, dtypes.bfloat16) # t1447: "cuda:0 bf16[1, 512, 4096]" - t1448 = prims.linear(t1447, t_transformer_h_10_attn_attn_weight, None) # t1448: "cuda:0 bf16[1, 512, 12288]" - t1454 = prims.reshape(t1448, (1, 512, 32, 3, 128)) # t1454: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1460 = prims.transpose(t1454, (0, 2, 3, 1, 4)) # t1460: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1461, t1462, t1463) = ltorch.split(t1460, (1, 1, 1), 2) - # t1461 = prims.slice_prim(t1460, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1461: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1462 = prims.slice_prim(t1460, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1462: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1463 = prims.slice_prim(t1460, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1463: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1469 = prims.reshape(t1461, (1, 32, 512, 128)) # t1469: "cuda:0 bf16[1, 32, 512, 128]" - t1475 = prims.reshape(t1462, (1, 32, 512, 128)) # t1475: "cuda:0 bf16[1, 32, 512, 128]" - t1481 = prims.reshape(t1463, (1, 32, 512, 128)) # t1481: "cuda:0 bf16[1, 32, 512, 128]" - t1482 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1482: "cuda:0 bf16[1, 32, 512, 128]" - t1483 = prims.slice_prim(t1482, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1483: "cuda:0 bf16[1, 32, 512, 64]" - t1484 = prims.slice_prim(t1482, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1484: "cuda:0 bf16[1, 32, 512, 64]" - t1485 = prims.convert_element_type(t1484, dtypes.float32) # t1485: "cuda:0 f32[1, 32, 512, 64]" - t1486 = prims.neg(t1485) # t1486: "cuda:0 f32[1, 32, 512, 64]" - t1487 = prims.convert_element_type(t1486, dtypes.bfloat16) # t1487: "cuda:0 bf16[1, 32, 512, 64]" - t1489 = prims.cat((t1487, t1483), -1) # t1489: "cuda:0 bf16[1, 32, 512, 128]" - t1490 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1490: "cuda:0 f32[1, 32, 512, 128]" - t1491 = prims.convert_element_type(t1482, dtypes.float32) # t1491: "cuda:0 f32[1, 32, 512, 128]" - t1492 = ltorch.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - # t1492 = prims.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - t1493 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1493: "cuda:0 f32[1, 32, 512, 128]" - t1494 = prims.convert_element_type(t1489, dtypes.float32) # t1494: "cuda:0 f32[1, 32, 512, 128]" - t1495 = ltorch.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - # t1495 = prims.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - t1496 = ltorch.add(t1492, t1495, alpha=None) # t1496: "cuda:0 f32[1, 32, 512, 128]" - # t1496 = prims.add(t1492, t1495) # t1496: "cuda:0 f32[1, 32, 512, 128]" - t1497 = prims.convert_element_type(t1496, dtypes.bfloat16) # t1497: "cuda:0 bf16[1, 32, 512, 128]" - t1498 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1498: "cuda:0 bf16[1, 32, 512, 128]" - t1499 = prims.slice_prim(t1498, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1499: "cuda:0 bf16[1, 32, 512, 64]" - t1500 = prims.slice_prim(t1498, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1500: "cuda:0 bf16[1, 32, 512, 64]" - t1501 = prims.convert_element_type(t1500, dtypes.float32) # t1501: "cuda:0 f32[1, 32, 512, 64]" - t1502 = prims.neg(t1501) # t1502: "cuda:0 f32[1, 32, 512, 64]" - t1503 = prims.convert_element_type(t1502, dtypes.bfloat16) # t1503: "cuda:0 bf16[1, 32, 512, 64]" - t1505 = prims.cat((t1503, t1499), -1) # t1505: "cuda:0 bf16[1, 32, 512, 128]" - t1506 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1506: "cuda:0 f32[1, 32, 512, 128]" - t1507 = prims.convert_element_type(t1498, dtypes.float32) # t1507: "cuda:0 f32[1, 32, 512, 128]" - t1508 = ltorch.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - # t1508 = prims.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - t1509 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1509: "cuda:0 f32[1, 32, 512, 128]" - t1510 = prims.convert_element_type(t1505, dtypes.float32) # t1510: "cuda:0 f32[1, 32, 512, 128]" - t1511 = ltorch.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - # t1511 = prims.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - t1512 = ltorch.add(t1508, t1511, alpha=None) # t1512: "cuda:0 f32[1, 32, 512, 128]" - # t1512 = prims.add(t1508, t1511) # t1512: "cuda:0 f32[1, 32, 512, 128]" - t1513 = prims.convert_element_type(t1512, dtypes.bfloat16) # t1513: "cuda:0 bf16[1, 32, 512, 128]" - t1514 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1514: "cuda:0 bf16[1, 32, 512, 0]" - t1516 = prims.cat((t1497, t1514), -1) # t1516: "cuda:0 bf16[1, 32, 512, 128]" - t1517 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1517: "cuda:0 bf16[1, 32, 512, 0]" - t1519 = prims.cat((t1513, t1517), -1) # t1519: "cuda:0 bf16[1, 32, 512, 128]" - (t1520, t1521, t1522, t1523) = cudnn_sdpa_fwd(t1516, t1519, t1481, None, 0.0, True, scale=0.08838834764831843) - t1526 = prims.transpose(t1520, (0, 2, 1, 3)) # t1526: "cuda:0 bf16[1, 512, 32, 128]" - t1530 = prims.reshape(t1526, (1, 512, 4096)) # t1530: "cuda:0 bf16[1, 512, 4096]" - t1531 = prims.linear(t1530, t_transformer_h_10_attn_proj_weight, None) # t1531: "cuda:0 bf16[1, 512, 4096]" - t1532 = prims.convert_element_type(t1531, dtypes.float32) # t1532: "cuda:0 f32[1, 512, 4096]" - t1533 = prims.convert_element_type(t1429, dtypes.float32) # t1533: "cuda:0 f32[1, 512, 4096]" - t1534 = ltorch.add(t1532, t1533, alpha=None) # t1534: "cuda:0 f32[1, 512, 4096]" - # t1534 = prims.add(t1532, t1533) # t1534: "cuda:0 f32[1, 512, 4096]" - t1535 = prims.convert_element_type(t1534, dtypes.bfloat16) # t1535: "cuda:0 bf16[1, 512, 4096]" - t1536 = prims.convert_element_type(t1535, dtypes.float32) # t1536: "cuda:0 f32[1, 512, 4096]" - t1537 = ltorch.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - # t1537 = prims.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - t1539 = prims.sum(t1537, (2,)) # t1539: "cuda:0 f32[1, 512]" - t1540 = prims.broadcast_in_dim(t1539, [1, 512, 1], [0, 1]) # t1540: "cuda:0 f32[1, 512, 1]" - t1542 = ltorch.true_divide(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - # t1542 = prims.div(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - t1544 = ltorch.add(t1542, 1e-05, alpha=None) # t1544: "cuda:0 f32[1, 512, 1]" - # t1544 = prims.add(t1542, 1e-05) # t1544: "cuda:0 f32[1, 512, 1]" - t1545 = prims.rsqrt(t1544) # t1545: "cuda:0 f32[1, 512, 1]" - t1546 = prims.broadcast_in_dim(t1545, (1, 512, 4096), (0, 1, 2)) # t1546: "cuda:0 f32[1, 512, 4096]" - t1547 = ltorch.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - # t1547 = prims.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - t1548 = prims.convert_element_type(t1547, dtypes.bfloat16) # t1548: "cuda:0 bf16[1, 512, 4096]" - t1549 = prims.broadcast_in_dim(t_transformer_h_10_norm_2_weight, (1, 512, 4096), (2,)) # t1549: "cuda:0 bf16[1, 512, 4096]" - t1550 = prims.convert_element_type(t1548, dtypes.float32) # t1550: "cuda:0 f32[1, 512, 4096]" - t1551 = prims.convert_element_type(t1549, dtypes.float32) # t1551: "cuda:0 f32[1, 512, 4096]" - t1552 = ltorch.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - # t1552 = prims.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - t1553 = prims.convert_element_type(t1552, dtypes.bfloat16) # t1553: "cuda:0 bf16[1, 512, 4096]" - t1554 = prims.linear(t1553, t_transformer_h_10_mlp_fc_1_weight, None) # t1554: "cuda:0 bf16[1, 512, 11008]" - t1555 = prims.linear(t1553, t_transformer_h_10_mlp_fc_2_weight, None) # t1555: "cuda:0 bf16[1, 512, 11008]" - t1556 = prims.convert_element_type(t1554, dtypes.float32) # t1556: "cuda:0 f32[1, 512, 11008]" - t1557 = prims.neg(t1556) # t1557: "cuda:0 f32[1, 512, 11008]" - t1558 = prims.exp(t1557) # t1558: "cuda:0 f32[1, 512, 11008]" - t1559 = ltorch.add(1.0, t1558, alpha=None) # t1559: "cuda:0 f32[1, 512, 11008]" - # t1559 = prims.add(1.0, t1558) # t1559: "cuda:0 f32[1, 512, 11008]" - t1560 = prims.reciprocal(t1559) # t1560: "cuda:0 f32[1, 512, 11008]" - t1561 = prims.convert_element_type(t1560, dtypes.bfloat16) # t1561: "cuda:0 bf16[1, 512, 11008]" - t1562 = prims.convert_element_type(t1554, dtypes.float32) # t1562: "cuda:0 f32[1, 512, 11008]" - t1563 = prims.convert_element_type(t1561, dtypes.float32) # t1563: "cuda:0 f32[1, 512, 11008]" - t1564 = ltorch.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - # t1564 = prims.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: "cuda:0 bf16[1, 512, 11008]" - t1566 = prims.convert_element_type(t1565, dtypes.float32) # t1566: "cuda:0 f32[1, 512, 11008]" - t1567 = prims.convert_element_type(t1555, dtypes.float32) # t1567: "cuda:0 f32[1, 512, 11008]" - t1568 = ltorch.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - # t1568 = prims.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - t1569 = prims.convert_element_type(t1568, dtypes.bfloat16) # t1569: "cuda:0 bf16[1, 512, 11008]" - t1570 = prims.linear(t1569, t_transformer_h_10_mlp_proj_weight, None) # t1570: "cuda:0 bf16[1, 512, 4096]" - t1571 = prims.convert_element_type(t1570, dtypes.float32) # t1571: "cuda:0 f32[1, 512, 4096]" - t1572 = prims.convert_element_type(t1535, dtypes.float32) # t1572: "cuda:0 f32[1, 512, 4096]" - t1573 = ltorch.add(t1571, t1572, alpha=None) # t1573: "cuda:0 f32[1, 512, 4096]" - # t1573 = prims.add(t1571, t1572) # t1573: "cuda:0 f32[1, 512, 4096]" - t1574 = prims.convert_element_type(t1573, dtypes.bfloat16) # t1574: "cuda:0 bf16[1, 512, 4096]" - t1575 = prims.convert_element_type(t1574, dtypes.float32) # t1575: "cuda:0 f32[1, 512, 4096]" - t1576 = ltorch.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - # t1576 = prims.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - t1578 = prims.sum(t1576, (2,)) # t1578: "cuda:0 f32[1, 512]" - t1579 = prims.broadcast_in_dim(t1578, [1, 512, 1], [0, 1]) # t1579: "cuda:0 f32[1, 512, 1]" - t1581 = ltorch.true_divide(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - # t1581 = prims.div(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - t1583 = ltorch.add(t1581, 1e-05, alpha=None) # t1583: "cuda:0 f32[1, 512, 1]" - # t1583 = prims.add(t1581, 1e-05) # t1583: "cuda:0 f32[1, 512, 1]" - t1584 = prims.rsqrt(t1583) # t1584: "cuda:0 f32[1, 512, 1]" - t1585 = prims.broadcast_in_dim(t1584, (1, 512, 4096), (0, 1, 2)) # t1585: "cuda:0 f32[1, 512, 4096]" - t1586 = ltorch.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - # t1586 = prims.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - t1587 = prims.convert_element_type(t1586, dtypes.bfloat16) # t1587: "cuda:0 bf16[1, 512, 4096]" - t1588 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_weight, (1, 512, 4096), (2,)) # t1588: "cuda:0 bf16[1, 512, 4096]" - t1589 = prims.convert_element_type(t1587, dtypes.float32) # t1589: "cuda:0 f32[1, 512, 4096]" - t1590 = prims.convert_element_type(t1588, dtypes.float32) # t1590: "cuda:0 f32[1, 512, 4096]" - t1591 = ltorch.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - # t1591 = prims.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - t1592 = prims.convert_element_type(t1591, dtypes.bfloat16) # t1592: "cuda:0 bf16[1, 512, 4096]" - t1593 = prims.linear(t1592, t_transformer_h_11_attn_attn_weight, None) # t1593: "cuda:0 bf16[1, 512, 12288]" - t1599 = prims.reshape(t1593, (1, 512, 32, 3, 128)) # t1599: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1605 = prims.transpose(t1599, (0, 2, 3, 1, 4)) # t1605: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1606, t1607, t1608) = ltorch.split(t1605, (1, 1, 1), 2) - # t1606 = prims.slice_prim(t1605, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1606: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1607 = prims.slice_prim(t1605, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1607: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1608 = prims.slice_prim(t1605, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1608: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1614 = prims.reshape(t1606, (1, 32, 512, 128)) # t1614: "cuda:0 bf16[1, 32, 512, 128]" - t1620 = prims.reshape(t1607, (1, 32, 512, 128)) # t1620: "cuda:0 bf16[1, 32, 512, 128]" - t1626 = prims.reshape(t1608, (1, 32, 512, 128)) # t1626: "cuda:0 bf16[1, 32, 512, 128]" - t1627 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1627: "cuda:0 bf16[1, 32, 512, 128]" - t1628 = prims.slice_prim(t1627, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1628: "cuda:0 bf16[1, 32, 512, 64]" - t1629 = prims.slice_prim(t1627, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1629: "cuda:0 bf16[1, 32, 512, 64]" - t1630 = prims.convert_element_type(t1629, dtypes.float32) # t1630: "cuda:0 f32[1, 32, 512, 64]" - t1631 = prims.neg(t1630) # t1631: "cuda:0 f32[1, 32, 512, 64]" - t1632 = prims.convert_element_type(t1631, dtypes.bfloat16) # t1632: "cuda:0 bf16[1, 32, 512, 64]" - t1634 = prims.cat((t1632, t1628), -1) # t1634: "cuda:0 bf16[1, 32, 512, 128]" - t1635 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1635: "cuda:0 f32[1, 32, 512, 128]" - t1636 = prims.convert_element_type(t1627, dtypes.float32) # t1636: "cuda:0 f32[1, 32, 512, 128]" - t1637 = ltorch.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - # t1637 = prims.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - t1638 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1638: "cuda:0 f32[1, 32, 512, 128]" - t1639 = prims.convert_element_type(t1634, dtypes.float32) # t1639: "cuda:0 f32[1, 32, 512, 128]" - t1640 = ltorch.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - # t1640 = prims.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - t1641 = ltorch.add(t1637, t1640, alpha=None) # t1641: "cuda:0 f32[1, 32, 512, 128]" - # t1641 = prims.add(t1637, t1640) # t1641: "cuda:0 f32[1, 32, 512, 128]" - t1642 = prims.convert_element_type(t1641, dtypes.bfloat16) # t1642: "cuda:0 bf16[1, 32, 512, 128]" - t1643 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1643: "cuda:0 bf16[1, 32, 512, 128]" - t1644 = prims.slice_prim(t1643, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1644: "cuda:0 bf16[1, 32, 512, 64]" - t1645 = prims.slice_prim(t1643, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1645: "cuda:0 bf16[1, 32, 512, 64]" - t1646 = prims.convert_element_type(t1645, dtypes.float32) # t1646: "cuda:0 f32[1, 32, 512, 64]" - t1647 = prims.neg(t1646) # t1647: "cuda:0 f32[1, 32, 512, 64]" - t1648 = prims.convert_element_type(t1647, dtypes.bfloat16) # t1648: "cuda:0 bf16[1, 32, 512, 64]" - t1650 = prims.cat((t1648, t1644), -1) # t1650: "cuda:0 bf16[1, 32, 512, 128]" - t1651 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1651: "cuda:0 f32[1, 32, 512, 128]" - t1652 = prims.convert_element_type(t1643, dtypes.float32) # t1652: "cuda:0 f32[1, 32, 512, 128]" - t1653 = ltorch.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - # t1653 = prims.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - t1654 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1654: "cuda:0 f32[1, 32, 512, 128]" - t1655 = prims.convert_element_type(t1650, dtypes.float32) # t1655: "cuda:0 f32[1, 32, 512, 128]" - t1656 = ltorch.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - # t1656 = prims.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - t1657 = ltorch.add(t1653, t1656, alpha=None) # t1657: "cuda:0 f32[1, 32, 512, 128]" - # t1657 = prims.add(t1653, t1656) # t1657: "cuda:0 f32[1, 32, 512, 128]" - t1658 = prims.convert_element_type(t1657, dtypes.bfloat16) # t1658: "cuda:0 bf16[1, 32, 512, 128]" - t1659 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1659: "cuda:0 bf16[1, 32, 512, 0]" - t1661 = prims.cat((t1642, t1659), -1) # t1661: "cuda:0 bf16[1, 32, 512, 128]" - t1662 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1662: "cuda:0 bf16[1, 32, 512, 0]" - t1664 = prims.cat((t1658, t1662), -1) # t1664: "cuda:0 bf16[1, 32, 512, 128]" - (t1665, t1666, t1667, t1668) = cudnn_sdpa_fwd(t1661, t1664, t1626, None, 0.0, True, scale=0.08838834764831843) - t1671 = prims.transpose(t1665, (0, 2, 1, 3)) # t1671: "cuda:0 bf16[1, 512, 32, 128]" - t1675 = prims.reshape(t1671, (1, 512, 4096)) # t1675: "cuda:0 bf16[1, 512, 4096]" - t1676 = prims.linear(t1675, t_transformer_h_11_attn_proj_weight, None) # t1676: "cuda:0 bf16[1, 512, 4096]" - t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: "cuda:0 f32[1, 512, 4096]" - t1678 = prims.convert_element_type(t1574, dtypes.float32) # t1678: "cuda:0 f32[1, 512, 4096]" - t1679 = ltorch.add(t1677, t1678, alpha=None) # t1679: "cuda:0 f32[1, 512, 4096]" - # t1679 = prims.add(t1677, t1678) # t1679: "cuda:0 f32[1, 512, 4096]" - t1680 = prims.convert_element_type(t1679, dtypes.bfloat16) # t1680: "cuda:0 bf16[1, 512, 4096]" - t1681 = prims.convert_element_type(t1680, dtypes.float32) # t1681: "cuda:0 f32[1, 512, 4096]" - t1682 = ltorch.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - # t1682 = prims.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - t1684 = prims.sum(t1682, (2,)) # t1684: "cuda:0 f32[1, 512]" - t1685 = prims.broadcast_in_dim(t1684, [1, 512, 1], [0, 1]) # t1685: "cuda:0 f32[1, 512, 1]" - t1687 = ltorch.true_divide(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - # t1687 = prims.div(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - t1689 = ltorch.add(t1687, 1e-05, alpha=None) # t1689: "cuda:0 f32[1, 512, 1]" - # t1689 = prims.add(t1687, 1e-05) # t1689: "cuda:0 f32[1, 512, 1]" - t1690 = prims.rsqrt(t1689) # t1690: "cuda:0 f32[1, 512, 1]" - t1691 = prims.broadcast_in_dim(t1690, (1, 512, 4096), (0, 1, 2)) # t1691: "cuda:0 f32[1, 512, 4096]" - t1692 = ltorch.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - # t1692 = prims.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - t1693 = prims.convert_element_type(t1692, dtypes.bfloat16) # t1693: "cuda:0 bf16[1, 512, 4096]" - t1694 = prims.broadcast_in_dim(t_transformer_h_11_norm_2_weight, (1, 512, 4096), (2,)) # t1694: "cuda:0 bf16[1, 512, 4096]" - t1695 = prims.convert_element_type(t1693, dtypes.float32) # t1695: "cuda:0 f32[1, 512, 4096]" - t1696 = prims.convert_element_type(t1694, dtypes.float32) # t1696: "cuda:0 f32[1, 512, 4096]" - t1697 = ltorch.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - # t1697 = prims.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - t1698 = prims.convert_element_type(t1697, dtypes.bfloat16) # t1698: "cuda:0 bf16[1, 512, 4096]" - t1699 = prims.linear(t1698, t_transformer_h_11_mlp_fc_1_weight, None) # t1699: "cuda:0 bf16[1, 512, 11008]" - t1700 = prims.linear(t1698, t_transformer_h_11_mlp_fc_2_weight, None) # t1700: "cuda:0 bf16[1, 512, 11008]" - t1701 = prims.convert_element_type(t1699, dtypes.float32) # t1701: "cuda:0 f32[1, 512, 11008]" - t1702 = prims.neg(t1701) # t1702: "cuda:0 f32[1, 512, 11008]" - t1703 = prims.exp(t1702) # t1703: "cuda:0 f32[1, 512, 11008]" - t1704 = ltorch.add(1.0, t1703, alpha=None) # t1704: "cuda:0 f32[1, 512, 11008]" - # t1704 = prims.add(1.0, t1703) # t1704: "cuda:0 f32[1, 512, 11008]" - t1705 = prims.reciprocal(t1704) # t1705: "cuda:0 f32[1, 512, 11008]" - t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: "cuda:0 bf16[1, 512, 11008]" - t1707 = prims.convert_element_type(t1699, dtypes.float32) # t1707: "cuda:0 f32[1, 512, 11008]" - t1708 = prims.convert_element_type(t1706, dtypes.float32) # t1708: "cuda:0 f32[1, 512, 11008]" - t1709 = ltorch.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - # t1709 = prims.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - t1710 = prims.convert_element_type(t1709, dtypes.bfloat16) # t1710: "cuda:0 bf16[1, 512, 11008]" - t1711 = prims.convert_element_type(t1710, dtypes.float32) # t1711: "cuda:0 f32[1, 512, 11008]" - t1712 = prims.convert_element_type(t1700, dtypes.float32) # t1712: "cuda:0 f32[1, 512, 11008]" - t1713 = ltorch.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - # t1713 = prims.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - t1714 = prims.convert_element_type(t1713, dtypes.bfloat16) # t1714: "cuda:0 bf16[1, 512, 11008]" - t1715 = prims.linear(t1714, t_transformer_h_11_mlp_proj_weight, None) # t1715: "cuda:0 bf16[1, 512, 4096]" - t1716 = prims.convert_element_type(t1715, dtypes.float32) # t1716: "cuda:0 f32[1, 512, 4096]" - t1717 = prims.convert_element_type(t1680, dtypes.float32) # t1717: "cuda:0 f32[1, 512, 4096]" - t1718 = ltorch.add(t1716, t1717, alpha=None) # t1718: "cuda:0 f32[1, 512, 4096]" - # t1718 = prims.add(t1716, t1717) # t1718: "cuda:0 f32[1, 512, 4096]" - t1719 = prims.convert_element_type(t1718, dtypes.bfloat16) # t1719: "cuda:0 bf16[1, 512, 4096]" - t1720 = prims.convert_element_type(t1719, dtypes.float32) # t1720: "cuda:0 f32[1, 512, 4096]" - t1721 = ltorch.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - # t1721 = prims.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - t1723 = prims.sum(t1721, (2,)) # t1723: "cuda:0 f32[1, 512]" - t1724 = prims.broadcast_in_dim(t1723, [1, 512, 1], [0, 1]) # t1724: "cuda:0 f32[1, 512, 1]" - t1726 = ltorch.true_divide(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - # t1726 = prims.div(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - t1728 = ltorch.add(t1726, 1e-05, alpha=None) # t1728: "cuda:0 f32[1, 512, 1]" - # t1728 = prims.add(t1726, 1e-05) # t1728: "cuda:0 f32[1, 512, 1]" - t1729 = prims.rsqrt(t1728) # t1729: "cuda:0 f32[1, 512, 1]" - t1730 = prims.broadcast_in_dim(t1729, (1, 512, 4096), (0, 1, 2)) # t1730: "cuda:0 f32[1, 512, 4096]" - t1731 = ltorch.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - # t1731 = prims.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - t1732 = prims.convert_element_type(t1731, dtypes.bfloat16) # t1732: "cuda:0 bf16[1, 512, 4096]" - t1733 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_weight, (1, 512, 4096), (2,)) # t1733: "cuda:0 bf16[1, 512, 4096]" - t1734 = prims.convert_element_type(t1732, dtypes.float32) # t1734: "cuda:0 f32[1, 512, 4096]" - t1735 = prims.convert_element_type(t1733, dtypes.float32) # t1735: "cuda:0 f32[1, 512, 4096]" - t1736 = ltorch.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - # t1736 = prims.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: "cuda:0 bf16[1, 512, 4096]" - t1738 = prims.linear(t1737, t_transformer_h_12_attn_attn_weight, None) # t1738: "cuda:0 bf16[1, 512, 12288]" - t1744 = prims.reshape(t1738, (1, 512, 32, 3, 128)) # t1744: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1750 = prims.transpose(t1744, (0, 2, 3, 1, 4)) # t1750: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1751, t1752, t1753) = ltorch.split(t1750, (1, 1, 1), 2) - # t1751 = prims.slice_prim(t1750, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1751: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1752 = prims.slice_prim(t1750, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1752: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1753 = prims.slice_prim(t1750, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1753: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1759 = prims.reshape(t1751, (1, 32, 512, 128)) # t1759: "cuda:0 bf16[1, 32, 512, 128]" - t1765 = prims.reshape(t1752, (1, 32, 512, 128)) # t1765: "cuda:0 bf16[1, 32, 512, 128]" - t1771 = prims.reshape(t1753, (1, 32, 512, 128)) # t1771: "cuda:0 bf16[1, 32, 512, 128]" - t1772 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1772: "cuda:0 bf16[1, 32, 512, 128]" - t1773 = prims.slice_prim(t1772, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1773: "cuda:0 bf16[1, 32, 512, 64]" - t1774 = prims.slice_prim(t1772, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1774: "cuda:0 bf16[1, 32, 512, 64]" - t1775 = prims.convert_element_type(t1774, dtypes.float32) # t1775: "cuda:0 f32[1, 32, 512, 64]" - t1776 = prims.neg(t1775) # t1776: "cuda:0 f32[1, 32, 512, 64]" - t1777 = prims.convert_element_type(t1776, dtypes.bfloat16) # t1777: "cuda:0 bf16[1, 32, 512, 64]" - t1779 = prims.cat((t1777, t1773), -1) # t1779: "cuda:0 bf16[1, 32, 512, 128]" - t1780 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1780: "cuda:0 f32[1, 32, 512, 128]" - t1781 = prims.convert_element_type(t1772, dtypes.float32) # t1781: "cuda:0 f32[1, 32, 512, 128]" - t1782 = ltorch.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - # t1782 = prims.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - t1783 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1783: "cuda:0 f32[1, 32, 512, 128]" - t1784 = prims.convert_element_type(t1779, dtypes.float32) # t1784: "cuda:0 f32[1, 32, 512, 128]" - t1785 = ltorch.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - # t1785 = prims.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - t1786 = ltorch.add(t1782, t1785, alpha=None) # t1786: "cuda:0 f32[1, 32, 512, 128]" - # t1786 = prims.add(t1782, t1785) # t1786: "cuda:0 f32[1, 32, 512, 128]" - t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: "cuda:0 bf16[1, 32, 512, 128]" - t1788 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1788: "cuda:0 bf16[1, 32, 512, 128]" - t1789 = prims.slice_prim(t1788, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1789: "cuda:0 bf16[1, 32, 512, 64]" - t1790 = prims.slice_prim(t1788, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1790: "cuda:0 bf16[1, 32, 512, 64]" - t1791 = prims.convert_element_type(t1790, dtypes.float32) # t1791: "cuda:0 f32[1, 32, 512, 64]" - t1792 = prims.neg(t1791) # t1792: "cuda:0 f32[1, 32, 512, 64]" - t1793 = prims.convert_element_type(t1792, dtypes.bfloat16) # t1793: "cuda:0 bf16[1, 32, 512, 64]" - t1795 = prims.cat((t1793, t1789), -1) # t1795: "cuda:0 bf16[1, 32, 512, 128]" - t1796 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1796: "cuda:0 f32[1, 32, 512, 128]" - t1797 = prims.convert_element_type(t1788, dtypes.float32) # t1797: "cuda:0 f32[1, 32, 512, 128]" - t1798 = ltorch.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - # t1798 = prims.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - t1799 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1799: "cuda:0 f32[1, 32, 512, 128]" - t1800 = prims.convert_element_type(t1795, dtypes.float32) # t1800: "cuda:0 f32[1, 32, 512, 128]" - t1801 = ltorch.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - # t1801 = prims.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - t1802 = ltorch.add(t1798, t1801, alpha=None) # t1802: "cuda:0 f32[1, 32, 512, 128]" - # t1802 = prims.add(t1798, t1801) # t1802: "cuda:0 f32[1, 32, 512, 128]" - t1803 = prims.convert_element_type(t1802, dtypes.bfloat16) # t1803: "cuda:0 bf16[1, 32, 512, 128]" - t1804 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1804: "cuda:0 bf16[1, 32, 512, 0]" - t1806 = prims.cat((t1787, t1804), -1) # t1806: "cuda:0 bf16[1, 32, 512, 128]" - t1807 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1807: "cuda:0 bf16[1, 32, 512, 0]" - t1809 = prims.cat((t1803, t1807), -1) # t1809: "cuda:0 bf16[1, 32, 512, 128]" - (t1810, t1811, t1812, t1813) = cudnn_sdpa_fwd(t1806, t1809, t1771, None, 0.0, True, scale=0.08838834764831843) - t1816 = prims.transpose(t1810, (0, 2, 1, 3)) # t1816: "cuda:0 bf16[1, 512, 32, 128]" - t1820 = prims.reshape(t1816, (1, 512, 4096)) # t1820: "cuda:0 bf16[1, 512, 4096]" - t1821 = prims.linear(t1820, t_transformer_h_12_attn_proj_weight, None) # t1821: "cuda:0 bf16[1, 512, 4096]" - t1822 = prims.convert_element_type(t1821, dtypes.float32) # t1822: "cuda:0 f32[1, 512, 4096]" - t1823 = prims.convert_element_type(t1719, dtypes.float32) # t1823: "cuda:0 f32[1, 512, 4096]" - t1824 = ltorch.add(t1822, t1823, alpha=None) # t1824: "cuda:0 f32[1, 512, 4096]" - # t1824 = prims.add(t1822, t1823) # t1824: "cuda:0 f32[1, 512, 4096]" - t1825 = prims.convert_element_type(t1824, dtypes.bfloat16) # t1825: "cuda:0 bf16[1, 512, 4096]" - t1826 = prims.convert_element_type(t1825, dtypes.float32) # t1826: "cuda:0 f32[1, 512, 4096]" - t1827 = ltorch.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - # t1827 = prims.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - t1829 = prims.sum(t1827, (2,)) # t1829: "cuda:0 f32[1, 512]" - t1830 = prims.broadcast_in_dim(t1829, [1, 512, 1], [0, 1]) # t1830: "cuda:0 f32[1, 512, 1]" - t1832 = ltorch.true_divide(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - # t1832 = prims.div(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - t1834 = ltorch.add(t1832, 1e-05, alpha=None) # t1834: "cuda:0 f32[1, 512, 1]" - # t1834 = prims.add(t1832, 1e-05) # t1834: "cuda:0 f32[1, 512, 1]" - t1835 = prims.rsqrt(t1834) # t1835: "cuda:0 f32[1, 512, 1]" - t1836 = prims.broadcast_in_dim(t1835, (1, 512, 4096), (0, 1, 2)) # t1836: "cuda:0 f32[1, 512, 4096]" - t1837 = ltorch.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - # t1837 = prims.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - t1838 = prims.convert_element_type(t1837, dtypes.bfloat16) # t1838: "cuda:0 bf16[1, 512, 4096]" - t1839 = prims.broadcast_in_dim(t_transformer_h_12_norm_2_weight, (1, 512, 4096), (2,)) # t1839: "cuda:0 bf16[1, 512, 4096]" - t1840 = prims.convert_element_type(t1838, dtypes.float32) # t1840: "cuda:0 f32[1, 512, 4096]" - t1841 = prims.convert_element_type(t1839, dtypes.float32) # t1841: "cuda:0 f32[1, 512, 4096]" - t1842 = ltorch.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - # t1842 = prims.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - t1843 = prims.convert_element_type(t1842, dtypes.bfloat16) # t1843: "cuda:0 bf16[1, 512, 4096]" - t1844 = prims.linear(t1843, t_transformer_h_12_mlp_fc_1_weight, None) # t1844: "cuda:0 bf16[1, 512, 11008]" - t1845 = prims.linear(t1843, t_transformer_h_12_mlp_fc_2_weight, None) # t1845: "cuda:0 bf16[1, 512, 11008]" - t1846 = prims.convert_element_type(t1844, dtypes.float32) # t1846: "cuda:0 f32[1, 512, 11008]" - t1847 = prims.neg(t1846) # t1847: "cuda:0 f32[1, 512, 11008]" - t1848 = prims.exp(t1847) # t1848: "cuda:0 f32[1, 512, 11008]" - t1849 = ltorch.add(1.0, t1848, alpha=None) # t1849: "cuda:0 f32[1, 512, 11008]" - # t1849 = prims.add(1.0, t1848) # t1849: "cuda:0 f32[1, 512, 11008]" - t1850 = prims.reciprocal(t1849) # t1850: "cuda:0 f32[1, 512, 11008]" - t1851 = prims.convert_element_type(t1850, dtypes.bfloat16) # t1851: "cuda:0 bf16[1, 512, 11008]" - t1852 = prims.convert_element_type(t1844, dtypes.float32) # t1852: "cuda:0 f32[1, 512, 11008]" - t1853 = prims.convert_element_type(t1851, dtypes.float32) # t1853: "cuda:0 f32[1, 512, 11008]" - t1854 = ltorch.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - # t1854 = prims.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - t1855 = prims.convert_element_type(t1854, dtypes.bfloat16) # t1855: "cuda:0 bf16[1, 512, 11008]" - t1856 = prims.convert_element_type(t1855, dtypes.float32) # t1856: "cuda:0 f32[1, 512, 11008]" - t1857 = prims.convert_element_type(t1845, dtypes.float32) # t1857: "cuda:0 f32[1, 512, 11008]" - t1858 = ltorch.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - # t1858 = prims.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - t1859 = prims.convert_element_type(t1858, dtypes.bfloat16) # t1859: "cuda:0 bf16[1, 512, 11008]" - t1860 = prims.linear(t1859, t_transformer_h_12_mlp_proj_weight, None) # t1860: "cuda:0 bf16[1, 512, 4096]" - t1861 = prims.convert_element_type(t1860, dtypes.float32) # t1861: "cuda:0 f32[1, 512, 4096]" - t1862 = prims.convert_element_type(t1825, dtypes.float32) # t1862: "cuda:0 f32[1, 512, 4096]" - t1863 = ltorch.add(t1861, t1862, alpha=None) # t1863: "cuda:0 f32[1, 512, 4096]" - # t1863 = prims.add(t1861, t1862) # t1863: "cuda:0 f32[1, 512, 4096]" - t1864 = prims.convert_element_type(t1863, dtypes.bfloat16) # t1864: "cuda:0 bf16[1, 512, 4096]" - t1865 = prims.convert_element_type(t1864, dtypes.float32) # t1865: "cuda:0 f32[1, 512, 4096]" - t1866 = ltorch.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - # t1866 = prims.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - t1868 = prims.sum(t1866, (2,)) # t1868: "cuda:0 f32[1, 512]" - t1869 = prims.broadcast_in_dim(t1868, [1, 512, 1], [0, 1]) # t1869: "cuda:0 f32[1, 512, 1]" - t1871 = ltorch.true_divide(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - # t1871 = prims.div(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - t1873 = ltorch.add(t1871, 1e-05, alpha=None) # t1873: "cuda:0 f32[1, 512, 1]" - # t1873 = prims.add(t1871, 1e-05) # t1873: "cuda:0 f32[1, 512, 1]" - t1874 = prims.rsqrt(t1873) # t1874: "cuda:0 f32[1, 512, 1]" - t1875 = prims.broadcast_in_dim(t1874, (1, 512, 4096), (0, 1, 2)) # t1875: "cuda:0 f32[1, 512, 4096]" - t1876 = ltorch.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - # t1876 = prims.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - t1877 = prims.convert_element_type(t1876, dtypes.bfloat16) # t1877: "cuda:0 bf16[1, 512, 4096]" - t1878 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_weight, (1, 512, 4096), (2,)) # t1878: "cuda:0 bf16[1, 512, 4096]" - t1879 = prims.convert_element_type(t1877, dtypes.float32) # t1879: "cuda:0 f32[1, 512, 4096]" - t1880 = prims.convert_element_type(t1878, dtypes.float32) # t1880: "cuda:0 f32[1, 512, 4096]" - t1881 = ltorch.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - # t1881 = prims.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - t1882 = prims.convert_element_type(t1881, dtypes.bfloat16) # t1882: "cuda:0 bf16[1, 512, 4096]" - t1883 = prims.linear(t1882, t_transformer_h_13_attn_attn_weight, None) # t1883: "cuda:0 bf16[1, 512, 12288]" - t1889 = prims.reshape(t1883, (1, 512, 32, 3, 128)) # t1889: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1895 = prims.transpose(t1889, (0, 2, 3, 1, 4)) # t1895: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1896, t1897, t1898) = ltorch.split(t1895, (1, 1, 1), 2) - # t1896 = prims.slice_prim(t1895, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1896: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1897 = prims.slice_prim(t1895, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1897: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1898 = prims.slice_prim(t1895, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1898: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1904 = prims.reshape(t1896, (1, 32, 512, 128)) # t1904: "cuda:0 bf16[1, 32, 512, 128]" - t1910 = prims.reshape(t1897, (1, 32, 512, 128)) # t1910: "cuda:0 bf16[1, 32, 512, 128]" - t1916 = prims.reshape(t1898, (1, 32, 512, 128)) # t1916: "cuda:0 bf16[1, 32, 512, 128]" - t1917 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - t1918 = prims.slice_prim(t1917, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1918: "cuda:0 bf16[1, 32, 512, 64]" - t1919 = prims.slice_prim(t1917, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1919: "cuda:0 bf16[1, 32, 512, 64]" - t1920 = prims.convert_element_type(t1919, dtypes.float32) # t1920: "cuda:0 f32[1, 32, 512, 64]" - t1921 = prims.neg(t1920) # t1921: "cuda:0 f32[1, 32, 512, 64]" - t1922 = prims.convert_element_type(t1921, dtypes.bfloat16) # t1922: "cuda:0 bf16[1, 32, 512, 64]" - t1924 = prims.cat((t1922, t1918), -1) # t1924: "cuda:0 bf16[1, 32, 512, 128]" - t1925 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1925: "cuda:0 f32[1, 32, 512, 128]" - t1926 = prims.convert_element_type(t1917, dtypes.float32) # t1926: "cuda:0 f32[1, 32, 512, 128]" - t1927 = ltorch.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - # t1927 = prims.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - t1928 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1928: "cuda:0 f32[1, 32, 512, 128]" - t1929 = prims.convert_element_type(t1924, dtypes.float32) # t1929: "cuda:0 f32[1, 32, 512, 128]" - t1930 = ltorch.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - # t1930 = prims.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - t1931 = ltorch.add(t1927, t1930, alpha=None) # t1931: "cuda:0 f32[1, 32, 512, 128]" - # t1931 = prims.add(t1927, t1930) # t1931: "cuda:0 f32[1, 32, 512, 128]" - t1932 = prims.convert_element_type(t1931, dtypes.bfloat16) # t1932: "cuda:0 bf16[1, 32, 512, 128]" - t1933 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1933: "cuda:0 bf16[1, 32, 512, 128]" - t1934 = prims.slice_prim(t1933, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1934: "cuda:0 bf16[1, 32, 512, 64]" - t1935 = prims.slice_prim(t1933, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1935: "cuda:0 bf16[1, 32, 512, 64]" - t1936 = prims.convert_element_type(t1935, dtypes.float32) # t1936: "cuda:0 f32[1, 32, 512, 64]" - t1937 = prims.neg(t1936) # t1937: "cuda:0 f32[1, 32, 512, 64]" - t1938 = prims.convert_element_type(t1937, dtypes.bfloat16) # t1938: "cuda:0 bf16[1, 32, 512, 64]" - t1940 = prims.cat((t1938, t1934), -1) # t1940: "cuda:0 bf16[1, 32, 512, 128]" - t1941 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1941: "cuda:0 f32[1, 32, 512, 128]" - t1942 = prims.convert_element_type(t1933, dtypes.float32) # t1942: "cuda:0 f32[1, 32, 512, 128]" - t1943 = ltorch.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - # t1943 = prims.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - t1944 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1944: "cuda:0 f32[1, 32, 512, 128]" - t1945 = prims.convert_element_type(t1940, dtypes.float32) # t1945: "cuda:0 f32[1, 32, 512, 128]" - t1946 = ltorch.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - # t1946 = prims.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - t1947 = ltorch.add(t1943, t1946, alpha=None) # t1947: "cuda:0 f32[1, 32, 512, 128]" - # t1947 = prims.add(t1943, t1946) # t1947: "cuda:0 f32[1, 32, 512, 128]" - t1948 = prims.convert_element_type(t1947, dtypes.bfloat16) # t1948: "cuda:0 bf16[1, 32, 512, 128]" - t1949 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1949: "cuda:0 bf16[1, 32, 512, 0]" - t1951 = prims.cat((t1932, t1949), -1) # t1951: "cuda:0 bf16[1, 32, 512, 128]" - t1952 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1952: "cuda:0 bf16[1, 32, 512, 0]" - t1954 = prims.cat((t1948, t1952), -1) # t1954: "cuda:0 bf16[1, 32, 512, 128]" - (t1955, t1956, t1957, t1958) = cudnn_sdpa_fwd(t1951, t1954, t1916, None, 0.0, True, scale=0.08838834764831843) - t1961 = prims.transpose(t1955, (0, 2, 1, 3)) # t1961: "cuda:0 bf16[1, 512, 32, 128]" - t1965 = prims.reshape(t1961, (1, 512, 4096)) # t1965: "cuda:0 bf16[1, 512, 4096]" - t1966 = prims.linear(t1965, t_transformer_h_13_attn_proj_weight, None) # t1966: "cuda:0 bf16[1, 512, 4096]" - t1967 = prims.convert_element_type(t1966, dtypes.float32) # t1967: "cuda:0 f32[1, 512, 4096]" - t1968 = prims.convert_element_type(t1864, dtypes.float32) # t1968: "cuda:0 f32[1, 512, 4096]" - t1969 = ltorch.add(t1967, t1968, alpha=None) # t1969: "cuda:0 f32[1, 512, 4096]" - # t1969 = prims.add(t1967, t1968) # t1969: "cuda:0 f32[1, 512, 4096]" - t1970 = prims.convert_element_type(t1969, dtypes.bfloat16) # t1970: "cuda:0 bf16[1, 512, 4096]" - t1971 = prims.convert_element_type(t1970, dtypes.float32) # t1971: "cuda:0 f32[1, 512, 4096]" - t1972 = ltorch.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - # t1972 = prims.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - t1974 = prims.sum(t1972, (2,)) # t1974: "cuda:0 f32[1, 512]" - t1975 = prims.broadcast_in_dim(t1974, [1, 512, 1], [0, 1]) # t1975: "cuda:0 f32[1, 512, 1]" - t1977 = ltorch.true_divide(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - # t1977 = prims.div(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - t1979 = ltorch.add(t1977, 1e-05, alpha=None) # t1979: "cuda:0 f32[1, 512, 1]" - # t1979 = prims.add(t1977, 1e-05) # t1979: "cuda:0 f32[1, 512, 1]" - t1980 = prims.rsqrt(t1979) # t1980: "cuda:0 f32[1, 512, 1]" - t1981 = prims.broadcast_in_dim(t1980, (1, 512, 4096), (0, 1, 2)) # t1981: "cuda:0 f32[1, 512, 4096]" - t1982 = ltorch.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - # t1982 = prims.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - t1983 = prims.convert_element_type(t1982, dtypes.bfloat16) # t1983: "cuda:0 bf16[1, 512, 4096]" - t1984 = prims.broadcast_in_dim(t_transformer_h_13_norm_2_weight, (1, 512, 4096), (2,)) # t1984: "cuda:0 bf16[1, 512, 4096]" - t1985 = prims.convert_element_type(t1983, dtypes.float32) # t1985: "cuda:0 f32[1, 512, 4096]" - t1986 = prims.convert_element_type(t1984, dtypes.float32) # t1986: "cuda:0 f32[1, 512, 4096]" - t1987 = ltorch.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - # t1987 = prims.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - t1988 = prims.convert_element_type(t1987, dtypes.bfloat16) # t1988: "cuda:0 bf16[1, 512, 4096]" - t1989 = prims.linear(t1988, t_transformer_h_13_mlp_fc_1_weight, None) # t1989: "cuda:0 bf16[1, 512, 11008]" - t1990 = prims.linear(t1988, t_transformer_h_13_mlp_fc_2_weight, None) # t1990: "cuda:0 bf16[1, 512, 11008]" - t1991 = prims.convert_element_type(t1989, dtypes.float32) # t1991: "cuda:0 f32[1, 512, 11008]" - t1992 = prims.neg(t1991) # t1992: "cuda:0 f32[1, 512, 11008]" - t1993 = prims.exp(t1992) # t1993: "cuda:0 f32[1, 512, 11008]" - t1994 = ltorch.add(1.0, t1993, alpha=None) # t1994: "cuda:0 f32[1, 512, 11008]" - # t1994 = prims.add(1.0, t1993) # t1994: "cuda:0 f32[1, 512, 11008]" - t1995 = prims.reciprocal(t1994) # t1995: "cuda:0 f32[1, 512, 11008]" - t1996 = prims.convert_element_type(t1995, dtypes.bfloat16) # t1996: "cuda:0 bf16[1, 512, 11008]" - t1997 = prims.convert_element_type(t1989, dtypes.float32) # t1997: "cuda:0 f32[1, 512, 11008]" - t1998 = prims.convert_element_type(t1996, dtypes.float32) # t1998: "cuda:0 f32[1, 512, 11008]" - t1999 = ltorch.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - # t1999 = prims.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - t2000 = prims.convert_element_type(t1999, dtypes.bfloat16) # t2000: "cuda:0 bf16[1, 512, 11008]" - t2001 = prims.convert_element_type(t2000, dtypes.float32) # t2001: "cuda:0 f32[1, 512, 11008]" - t2002 = prims.convert_element_type(t1990, dtypes.float32) # t2002: "cuda:0 f32[1, 512, 11008]" - t2003 = ltorch.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - # t2003 = prims.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - t2004 = prims.convert_element_type(t2003, dtypes.bfloat16) # t2004: "cuda:0 bf16[1, 512, 11008]" - t2005 = prims.linear(t2004, t_transformer_h_13_mlp_proj_weight, None) # t2005: "cuda:0 bf16[1, 512, 4096]" - t2006 = prims.convert_element_type(t2005, dtypes.float32) # t2006: "cuda:0 f32[1, 512, 4096]" - t2007 = prims.convert_element_type(t1970, dtypes.float32) # t2007: "cuda:0 f32[1, 512, 4096]" - t2008 = ltorch.add(t2006, t2007, alpha=None) # t2008: "cuda:0 f32[1, 512, 4096]" - # t2008 = prims.add(t2006, t2007) # t2008: "cuda:0 f32[1, 512, 4096]" - t2009 = prims.convert_element_type(t2008, dtypes.bfloat16) # t2009: "cuda:0 bf16[1, 512, 4096]" - t2010 = prims.convert_element_type(t2009, dtypes.float32) # t2010: "cuda:0 f32[1, 512, 4096]" - t2011 = ltorch.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - # t2011 = prims.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - t2013 = prims.sum(t2011, (2,)) # t2013: "cuda:0 f32[1, 512]" - t2014 = prims.broadcast_in_dim(t2013, [1, 512, 1], [0, 1]) # t2014: "cuda:0 f32[1, 512, 1]" - t2016 = ltorch.true_divide(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - # t2016 = prims.div(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - t2018 = ltorch.add(t2016, 1e-05, alpha=None) # t2018: "cuda:0 f32[1, 512, 1]" - # t2018 = prims.add(t2016, 1e-05) # t2018: "cuda:0 f32[1, 512, 1]" - t2019 = prims.rsqrt(t2018) # t2019: "cuda:0 f32[1, 512, 1]" - t2020 = prims.broadcast_in_dim(t2019, (1, 512, 4096), (0, 1, 2)) # t2020: "cuda:0 f32[1, 512, 4096]" - t2021 = ltorch.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - # t2021 = prims.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - t2022 = prims.convert_element_type(t2021, dtypes.bfloat16) # t2022: "cuda:0 bf16[1, 512, 4096]" - t2023 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_weight, (1, 512, 4096), (2,)) # t2023: "cuda:0 bf16[1, 512, 4096]" - t2024 = prims.convert_element_type(t2022, dtypes.float32) # t2024: "cuda:0 f32[1, 512, 4096]" - t2025 = prims.convert_element_type(t2023, dtypes.float32) # t2025: "cuda:0 f32[1, 512, 4096]" - t2026 = ltorch.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - # t2026 = prims.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - t2027 = prims.convert_element_type(t2026, dtypes.bfloat16) # t2027: "cuda:0 bf16[1, 512, 4096]" - t2028 = prims.linear(t2027, t_transformer_h_14_attn_attn_weight, None) # t2028: "cuda:0 bf16[1, 512, 12288]" - t2034 = prims.reshape(t2028, (1, 512, 32, 3, 128)) # t2034: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2040 = prims.transpose(t2034, (0, 2, 3, 1, 4)) # t2040: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2041, t2042, t2043) = ltorch.split(t2040, (1, 1, 1), 2) - # t2041 = prims.slice_prim(t2040, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2041: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2042 = prims.slice_prim(t2040, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2042: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2043 = prims.slice_prim(t2040, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2043: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2049 = prims.reshape(t2041, (1, 32, 512, 128)) # t2049: "cuda:0 bf16[1, 32, 512, 128]" - t2055 = prims.reshape(t2042, (1, 32, 512, 128)) # t2055: "cuda:0 bf16[1, 32, 512, 128]" - t2061 = prims.reshape(t2043, (1, 32, 512, 128)) # t2061: "cuda:0 bf16[1, 32, 512, 128]" - t2062 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2062: "cuda:0 bf16[1, 32, 512, 128]" - t2063 = prims.slice_prim(t2062, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2063: "cuda:0 bf16[1, 32, 512, 64]" - t2064 = prims.slice_prim(t2062, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2064: "cuda:0 bf16[1, 32, 512, 64]" - t2065 = prims.convert_element_type(t2064, dtypes.float32) # t2065: "cuda:0 f32[1, 32, 512, 64]" - t2066 = prims.neg(t2065) # t2066: "cuda:0 f32[1, 32, 512, 64]" - t2067 = prims.convert_element_type(t2066, dtypes.bfloat16) # t2067: "cuda:0 bf16[1, 32, 512, 64]" - t2069 = prims.cat((t2067, t2063), -1) # t2069: "cuda:0 bf16[1, 32, 512, 128]" - t2070 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2070: "cuda:0 f32[1, 32, 512, 128]" - t2071 = prims.convert_element_type(t2062, dtypes.float32) # t2071: "cuda:0 f32[1, 32, 512, 128]" - t2072 = ltorch.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - # t2072 = prims.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - t2073 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2073: "cuda:0 f32[1, 32, 512, 128]" - t2074 = prims.convert_element_type(t2069, dtypes.float32) # t2074: "cuda:0 f32[1, 32, 512, 128]" - t2075 = ltorch.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - # t2075 = prims.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - t2076 = ltorch.add(t2072, t2075, alpha=None) # t2076: "cuda:0 f32[1, 32, 512, 128]" - # t2076 = prims.add(t2072, t2075) # t2076: "cuda:0 f32[1, 32, 512, 128]" - t2077 = prims.convert_element_type(t2076, dtypes.bfloat16) # t2077: "cuda:0 bf16[1, 32, 512, 128]" - t2078 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2078: "cuda:0 bf16[1, 32, 512, 128]" - t2079 = prims.slice_prim(t2078, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2079: "cuda:0 bf16[1, 32, 512, 64]" - t2080 = prims.slice_prim(t2078, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2080: "cuda:0 bf16[1, 32, 512, 64]" - t2081 = prims.convert_element_type(t2080, dtypes.float32) # t2081: "cuda:0 f32[1, 32, 512, 64]" - t2082 = prims.neg(t2081) # t2082: "cuda:0 f32[1, 32, 512, 64]" - t2083 = prims.convert_element_type(t2082, dtypes.bfloat16) # t2083: "cuda:0 bf16[1, 32, 512, 64]" - t2085 = prims.cat((t2083, t2079), -1) # t2085: "cuda:0 bf16[1, 32, 512, 128]" - t2086 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2086: "cuda:0 f32[1, 32, 512, 128]" - t2087 = prims.convert_element_type(t2078, dtypes.float32) # t2087: "cuda:0 f32[1, 32, 512, 128]" - t2088 = ltorch.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - # t2088 = prims.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - t2089 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2089: "cuda:0 f32[1, 32, 512, 128]" - t2090 = prims.convert_element_type(t2085, dtypes.float32) # t2090: "cuda:0 f32[1, 32, 512, 128]" - t2091 = ltorch.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - # t2091 = prims.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - t2092 = ltorch.add(t2088, t2091, alpha=None) # t2092: "cuda:0 f32[1, 32, 512, 128]" - # t2092 = prims.add(t2088, t2091) # t2092: "cuda:0 f32[1, 32, 512, 128]" - t2093 = prims.convert_element_type(t2092, dtypes.bfloat16) # t2093: "cuda:0 bf16[1, 32, 512, 128]" - t2094 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2094: "cuda:0 bf16[1, 32, 512, 0]" - t2096 = prims.cat((t2077, t2094), -1) # t2096: "cuda:0 bf16[1, 32, 512, 128]" - t2097 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2097: "cuda:0 bf16[1, 32, 512, 0]" - t2099 = prims.cat((t2093, t2097), -1) # t2099: "cuda:0 bf16[1, 32, 512, 128]" - (t2100, t2101, t2102, t2103) = cudnn_sdpa_fwd(t2096, t2099, t2061, None, 0.0, True, scale=0.08838834764831843) - t2106 = prims.transpose(t2100, (0, 2, 1, 3)) # t2106: "cuda:0 bf16[1, 512, 32, 128]" - t2110 = prims.reshape(t2106, (1, 512, 4096)) # t2110: "cuda:0 bf16[1, 512, 4096]" - t2111 = prims.linear(t2110, t_transformer_h_14_attn_proj_weight, None) # t2111: "cuda:0 bf16[1, 512, 4096]" - t2112 = prims.convert_element_type(t2111, dtypes.float32) # t2112: "cuda:0 f32[1, 512, 4096]" - t2113 = prims.convert_element_type(t2009, dtypes.float32) # t2113: "cuda:0 f32[1, 512, 4096]" - t2114 = ltorch.add(t2112, t2113, alpha=None) # t2114: "cuda:0 f32[1, 512, 4096]" - # t2114 = prims.add(t2112, t2113) # t2114: "cuda:0 f32[1, 512, 4096]" - t2115 = prims.convert_element_type(t2114, dtypes.bfloat16) # t2115: "cuda:0 bf16[1, 512, 4096]" - t2116 = prims.convert_element_type(t2115, dtypes.float32) # t2116: "cuda:0 f32[1, 512, 4096]" - t2117 = ltorch.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - # t2117 = prims.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - t2119 = prims.sum(t2117, (2,)) # t2119: "cuda:0 f32[1, 512]" - t2120 = prims.broadcast_in_dim(t2119, [1, 512, 1], [0, 1]) # t2120: "cuda:0 f32[1, 512, 1]" - t2122 = ltorch.true_divide(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - # t2122 = prims.div(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - t2124 = ltorch.add(t2122, 1e-05, alpha=None) # t2124: "cuda:0 f32[1, 512, 1]" - # t2124 = prims.add(t2122, 1e-05) # t2124: "cuda:0 f32[1, 512, 1]" - t2125 = prims.rsqrt(t2124) # t2125: "cuda:0 f32[1, 512, 1]" - t2126 = prims.broadcast_in_dim(t2125, (1, 512, 4096), (0, 1, 2)) # t2126: "cuda:0 f32[1, 512, 4096]" - t2127 = ltorch.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - # t2127 = prims.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - t2128 = prims.convert_element_type(t2127, dtypes.bfloat16) # t2128: "cuda:0 bf16[1, 512, 4096]" - t2129 = prims.broadcast_in_dim(t_transformer_h_14_norm_2_weight, (1, 512, 4096), (2,)) # t2129: "cuda:0 bf16[1, 512, 4096]" - t2130 = prims.convert_element_type(t2128, dtypes.float32) # t2130: "cuda:0 f32[1, 512, 4096]" - t2131 = prims.convert_element_type(t2129, dtypes.float32) # t2131: "cuda:0 f32[1, 512, 4096]" - t2132 = ltorch.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - # t2132 = prims.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - t2133 = prims.convert_element_type(t2132, dtypes.bfloat16) # t2133: "cuda:0 bf16[1, 512, 4096]" - t2134 = prims.linear(t2133, t_transformer_h_14_mlp_fc_1_weight, None) # t2134: "cuda:0 bf16[1, 512, 11008]" - t2135 = prims.linear(t2133, t_transformer_h_14_mlp_fc_2_weight, None) # t2135: "cuda:0 bf16[1, 512, 11008]" - t2136 = prims.convert_element_type(t2134, dtypes.float32) # t2136: "cuda:0 f32[1, 512, 11008]" - t2137 = prims.neg(t2136) # t2137: "cuda:0 f32[1, 512, 11008]" - t2138 = prims.exp(t2137) # t2138: "cuda:0 f32[1, 512, 11008]" - t2139 = ltorch.add(1.0, t2138, alpha=None) # t2139: "cuda:0 f32[1, 512, 11008]" - # t2139 = prims.add(1.0, t2138) # t2139: "cuda:0 f32[1, 512, 11008]" - t2140 = prims.reciprocal(t2139) # t2140: "cuda:0 f32[1, 512, 11008]" - t2141 = prims.convert_element_type(t2140, dtypes.bfloat16) # t2141: "cuda:0 bf16[1, 512, 11008]" - t2142 = prims.convert_element_type(t2134, dtypes.float32) # t2142: "cuda:0 f32[1, 512, 11008]" - t2143 = prims.convert_element_type(t2141, dtypes.float32) # t2143: "cuda:0 f32[1, 512, 11008]" - t2144 = ltorch.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - # t2144 = prims.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - t2145 = prims.convert_element_type(t2144, dtypes.bfloat16) # t2145: "cuda:0 bf16[1, 512, 11008]" - t2146 = prims.convert_element_type(t2145, dtypes.float32) # t2146: "cuda:0 f32[1, 512, 11008]" - t2147 = prims.convert_element_type(t2135, dtypes.float32) # t2147: "cuda:0 f32[1, 512, 11008]" - t2148 = ltorch.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - # t2148 = prims.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - t2149 = prims.convert_element_type(t2148, dtypes.bfloat16) # t2149: "cuda:0 bf16[1, 512, 11008]" - t2150 = prims.linear(t2149, t_transformer_h_14_mlp_proj_weight, None) # t2150: "cuda:0 bf16[1, 512, 4096]" - t2151 = prims.convert_element_type(t2150, dtypes.float32) # t2151: "cuda:0 f32[1, 512, 4096]" - t2152 = prims.convert_element_type(t2115, dtypes.float32) # t2152: "cuda:0 f32[1, 512, 4096]" - t2153 = ltorch.add(t2151, t2152, alpha=None) # t2153: "cuda:0 f32[1, 512, 4096]" - # t2153 = prims.add(t2151, t2152) # t2153: "cuda:0 f32[1, 512, 4096]" - t2154 = prims.convert_element_type(t2153, dtypes.bfloat16) # t2154: "cuda:0 bf16[1, 512, 4096]" - t2155 = prims.convert_element_type(t2154, dtypes.float32) # t2155: "cuda:0 f32[1, 512, 4096]" - t2156 = ltorch.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - # t2156 = prims.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - t2158 = prims.sum(t2156, (2,)) # t2158: "cuda:0 f32[1, 512]" - t2159 = prims.broadcast_in_dim(t2158, [1, 512, 1], [0, 1]) # t2159: "cuda:0 f32[1, 512, 1]" - t2161 = ltorch.true_divide(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - # t2161 = prims.div(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - t2163 = ltorch.add(t2161, 1e-05, alpha=None) # t2163: "cuda:0 f32[1, 512, 1]" - # t2163 = prims.add(t2161, 1e-05) # t2163: "cuda:0 f32[1, 512, 1]" - t2164 = prims.rsqrt(t2163) # t2164: "cuda:0 f32[1, 512, 1]" - t2165 = prims.broadcast_in_dim(t2164, (1, 512, 4096), (0, 1, 2)) # t2165: "cuda:0 f32[1, 512, 4096]" - t2166 = ltorch.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - # t2166 = prims.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - t2167 = prims.convert_element_type(t2166, dtypes.bfloat16) # t2167: "cuda:0 bf16[1, 512, 4096]" - t2168 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_weight, (1, 512, 4096), (2,)) # t2168: "cuda:0 bf16[1, 512, 4096]" - t2169 = prims.convert_element_type(t2167, dtypes.float32) # t2169: "cuda:0 f32[1, 512, 4096]" - t2170 = prims.convert_element_type(t2168, dtypes.float32) # t2170: "cuda:0 f32[1, 512, 4096]" - t2171 = ltorch.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - # t2171 = prims.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - t2172 = prims.convert_element_type(t2171, dtypes.bfloat16) # t2172: "cuda:0 bf16[1, 512, 4096]" - t2173 = prims.linear(t2172, t_transformer_h_15_attn_attn_weight, None) # t2173: "cuda:0 bf16[1, 512, 12288]" - t2179 = prims.reshape(t2173, (1, 512, 32, 3, 128)) # t2179: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2185 = prims.transpose(t2179, (0, 2, 3, 1, 4)) # t2185: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2186, t2187, t2188) = ltorch.split(t2185, (1, 1, 1), 2) - # t2186 = prims.slice_prim(t2185, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2186: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2187 = prims.slice_prim(t2185, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2187: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2188 = prims.slice_prim(t2185, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2188: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2194 = prims.reshape(t2186, (1, 32, 512, 128)) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - t2200 = prims.reshape(t2187, (1, 32, 512, 128)) # t2200: "cuda:0 bf16[1, 32, 512, 128]" - t2206 = prims.reshape(t2188, (1, 32, 512, 128)) # t2206: "cuda:0 bf16[1, 32, 512, 128]" - t2207 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2207: "cuda:0 bf16[1, 32, 512, 128]" - t2208 = prims.slice_prim(t2207, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2208: "cuda:0 bf16[1, 32, 512, 64]" - t2209 = prims.slice_prim(t2207, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2209: "cuda:0 bf16[1, 32, 512, 64]" - t2210 = prims.convert_element_type(t2209, dtypes.float32) # t2210: "cuda:0 f32[1, 32, 512, 64]" - t2211 = prims.neg(t2210) # t2211: "cuda:0 f32[1, 32, 512, 64]" - t2212 = prims.convert_element_type(t2211, dtypes.bfloat16) # t2212: "cuda:0 bf16[1, 32, 512, 64]" - t2214 = prims.cat((t2212, t2208), -1) # t2214: "cuda:0 bf16[1, 32, 512, 128]" - t2215 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2215: "cuda:0 f32[1, 32, 512, 128]" - t2216 = prims.convert_element_type(t2207, dtypes.float32) # t2216: "cuda:0 f32[1, 32, 512, 128]" - t2217 = ltorch.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - # t2217 = prims.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - t2218 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2218: "cuda:0 f32[1, 32, 512, 128]" - t2219 = prims.convert_element_type(t2214, dtypes.float32) # t2219: "cuda:0 f32[1, 32, 512, 128]" - t2220 = ltorch.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - # t2220 = prims.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - t2221 = ltorch.add(t2217, t2220, alpha=None) # t2221: "cuda:0 f32[1, 32, 512, 128]" - # t2221 = prims.add(t2217, t2220) # t2221: "cuda:0 f32[1, 32, 512, 128]" - t2222 = prims.convert_element_type(t2221, dtypes.bfloat16) # t2222: "cuda:0 bf16[1, 32, 512, 128]" - t2223 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2223: "cuda:0 bf16[1, 32, 512, 128]" - t2224 = prims.slice_prim(t2223, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2224: "cuda:0 bf16[1, 32, 512, 64]" - t2225 = prims.slice_prim(t2223, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2225: "cuda:0 bf16[1, 32, 512, 64]" - t2226 = prims.convert_element_type(t2225, dtypes.float32) # t2226: "cuda:0 f32[1, 32, 512, 64]" - t2227 = prims.neg(t2226) # t2227: "cuda:0 f32[1, 32, 512, 64]" - t2228 = prims.convert_element_type(t2227, dtypes.bfloat16) # t2228: "cuda:0 bf16[1, 32, 512, 64]" - t2230 = prims.cat((t2228, t2224), -1) # t2230: "cuda:0 bf16[1, 32, 512, 128]" - t2231 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2231: "cuda:0 f32[1, 32, 512, 128]" - t2232 = prims.convert_element_type(t2223, dtypes.float32) # t2232: "cuda:0 f32[1, 32, 512, 128]" - t2233 = ltorch.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - # t2233 = prims.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - t2234 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2234: "cuda:0 f32[1, 32, 512, 128]" - t2235 = prims.convert_element_type(t2230, dtypes.float32) # t2235: "cuda:0 f32[1, 32, 512, 128]" - t2236 = ltorch.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - # t2236 = prims.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - t2237 = ltorch.add(t2233, t2236, alpha=None) # t2237: "cuda:0 f32[1, 32, 512, 128]" - # t2237 = prims.add(t2233, t2236) # t2237: "cuda:0 f32[1, 32, 512, 128]" - t2238 = prims.convert_element_type(t2237, dtypes.bfloat16) # t2238: "cuda:0 bf16[1, 32, 512, 128]" - t2239 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2239: "cuda:0 bf16[1, 32, 512, 0]" - t2241 = prims.cat((t2222, t2239), -1) # t2241: "cuda:0 bf16[1, 32, 512, 128]" - t2242 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2242: "cuda:0 bf16[1, 32, 512, 0]" - t2244 = prims.cat((t2238, t2242), -1) # t2244: "cuda:0 bf16[1, 32, 512, 128]" - (t2245, t2246, t2247, t2248) = cudnn_sdpa_fwd(t2241, t2244, t2206, None, 0.0, True, scale=0.08838834764831843) - t2251 = prims.transpose(t2245, (0, 2, 1, 3)) # t2251: "cuda:0 bf16[1, 512, 32, 128]" - t2255 = prims.reshape(t2251, (1, 512, 4096)) # t2255: "cuda:0 bf16[1, 512, 4096]" - t2256 = prims.linear(t2255, t_transformer_h_15_attn_proj_weight, None) # t2256: "cuda:0 bf16[1, 512, 4096]" - t2257 = prims.convert_element_type(t2256, dtypes.float32) # t2257: "cuda:0 f32[1, 512, 4096]" - t2258 = prims.convert_element_type(t2154, dtypes.float32) # t2258: "cuda:0 f32[1, 512, 4096]" - t2259 = ltorch.add(t2257, t2258, alpha=None) # t2259: "cuda:0 f32[1, 512, 4096]" - # t2259 = prims.add(t2257, t2258) # t2259: "cuda:0 f32[1, 512, 4096]" - t2260 = prims.convert_element_type(t2259, dtypes.bfloat16) # t2260: "cuda:0 bf16[1, 512, 4096]" - t2261 = prims.convert_element_type(t2260, dtypes.float32) # t2261: "cuda:0 f32[1, 512, 4096]" - t2262 = ltorch.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - # t2262 = prims.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - t2264 = prims.sum(t2262, (2,)) # t2264: "cuda:0 f32[1, 512]" - t2265 = prims.broadcast_in_dim(t2264, [1, 512, 1], [0, 1]) # t2265: "cuda:0 f32[1, 512, 1]" - t2267 = ltorch.true_divide(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - # t2267 = prims.div(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - t2269 = ltorch.add(t2267, 1e-05, alpha=None) # t2269: "cuda:0 f32[1, 512, 1]" - # t2269 = prims.add(t2267, 1e-05) # t2269: "cuda:0 f32[1, 512, 1]" - t2270 = prims.rsqrt(t2269) # t2270: "cuda:0 f32[1, 512, 1]" - t2271 = prims.broadcast_in_dim(t2270, (1, 512, 4096), (0, 1, 2)) # t2271: "cuda:0 f32[1, 512, 4096]" - t2272 = ltorch.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - # t2272 = prims.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - t2273 = prims.convert_element_type(t2272, dtypes.bfloat16) # t2273: "cuda:0 bf16[1, 512, 4096]" - t2274 = prims.broadcast_in_dim(t_transformer_h_15_norm_2_weight, (1, 512, 4096), (2,)) # t2274: "cuda:0 bf16[1, 512, 4096]" - t2275 = prims.convert_element_type(t2273, dtypes.float32) # t2275: "cuda:0 f32[1, 512, 4096]" - t2276 = prims.convert_element_type(t2274, dtypes.float32) # t2276: "cuda:0 f32[1, 512, 4096]" - t2277 = ltorch.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - # t2277 = prims.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - t2278 = prims.convert_element_type(t2277, dtypes.bfloat16) # t2278: "cuda:0 bf16[1, 512, 4096]" - t2279 = prims.linear(t2278, t_transformer_h_15_mlp_fc_1_weight, None) # t2279: "cuda:0 bf16[1, 512, 11008]" - t2280 = prims.linear(t2278, t_transformer_h_15_mlp_fc_2_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - t2281 = prims.convert_element_type(t2279, dtypes.float32) # t2281: "cuda:0 f32[1, 512, 11008]" - t2282 = prims.neg(t2281) # t2282: "cuda:0 f32[1, 512, 11008]" - t2283 = prims.exp(t2282) # t2283: "cuda:0 f32[1, 512, 11008]" - t2284 = ltorch.add(1.0, t2283, alpha=None) # t2284: "cuda:0 f32[1, 512, 11008]" - # t2284 = prims.add(1.0, t2283) # t2284: "cuda:0 f32[1, 512, 11008]" - t2285 = prims.reciprocal(t2284) # t2285: "cuda:0 f32[1, 512, 11008]" - t2286 = prims.convert_element_type(t2285, dtypes.bfloat16) # t2286: "cuda:0 bf16[1, 512, 11008]" - t2287 = prims.convert_element_type(t2279, dtypes.float32) # t2287: "cuda:0 f32[1, 512, 11008]" - t2288 = prims.convert_element_type(t2286, dtypes.float32) # t2288: "cuda:0 f32[1, 512, 11008]" - t2289 = ltorch.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - # t2289 = prims.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - t2290 = prims.convert_element_type(t2289, dtypes.bfloat16) # t2290: "cuda:0 bf16[1, 512, 11008]" - t2291 = prims.convert_element_type(t2290, dtypes.float32) # t2291: "cuda:0 f32[1, 512, 11008]" - t2292 = prims.convert_element_type(t2280, dtypes.float32) # t2292: "cuda:0 f32[1, 512, 11008]" - t2293 = ltorch.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - # t2293 = prims.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - t2294 = prims.convert_element_type(t2293, dtypes.bfloat16) # t2294: "cuda:0 bf16[1, 512, 11008]" - t2295 = prims.linear(t2294, t_transformer_h_15_mlp_proj_weight, None) # t2295: "cuda:0 bf16[1, 512, 4096]" - t2296 = prims.convert_element_type(t2295, dtypes.float32) # t2296: "cuda:0 f32[1, 512, 4096]" - t2297 = prims.convert_element_type(t2260, dtypes.float32) # t2297: "cuda:0 f32[1, 512, 4096]" - t2298 = ltorch.add(t2296, t2297, alpha=None) # t2298: "cuda:0 f32[1, 512, 4096]" - # t2298 = prims.add(t2296, t2297) # t2298: "cuda:0 f32[1, 512, 4096]" - t2299 = prims.convert_element_type(t2298, dtypes.bfloat16) # t2299: "cuda:0 bf16[1, 512, 4096]" - t2300 = prims.convert_element_type(t2299, dtypes.float32) # t2300: "cuda:0 f32[1, 512, 4096]" - t2301 = ltorch.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - # t2301 = prims.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - t2303 = prims.sum(t2301, (2,)) # t2303: "cuda:0 f32[1, 512]" - t2304 = prims.broadcast_in_dim(t2303, [1, 512, 1], [0, 1]) # t2304: "cuda:0 f32[1, 512, 1]" - t2306 = ltorch.true_divide(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - # t2306 = prims.div(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - t2308 = ltorch.add(t2306, 1e-05, alpha=None) # t2308: "cuda:0 f32[1, 512, 1]" - # t2308 = prims.add(t2306, 1e-05) # t2308: "cuda:0 f32[1, 512, 1]" - t2309 = prims.rsqrt(t2308) # t2309: "cuda:0 f32[1, 512, 1]" - t2310 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t2310: "cuda:0 f32[1, 512, 4096]" - t2311 = ltorch.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - # t2311 = prims.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - t2312 = prims.convert_element_type(t2311, dtypes.bfloat16) # t2312: "cuda:0 bf16[1, 512, 4096]" - t2313 = prims.broadcast_in_dim(t_transformer_ln_f_weight, (1, 512, 4096), (2,)) # t2313: "cuda:0 bf16[1, 512, 4096]" - t2314 = prims.convert_element_type(t2312, dtypes.float32) # t2314: "cuda:0 f32[1, 512, 4096]" - t2315 = prims.convert_element_type(t2313, dtypes.float32) # t2315: "cuda:0 f32[1, 512, 4096]" - t2316 = ltorch.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - # t2316 = prims.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - t2317 = prims.convert_element_type(t2316, dtypes.bfloat16) # t2317: "cuda:0 bf16[1, 512, 4096]" - t2318 = prims.linear(t2317, t_lm_head_weight, None) # t2318: "cuda:0 bf16[1, 512, 32000]" - return {'output': t2318, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t2318,)}, ((idx, t5, t11, t12, t17, t16, t19, t_transformer_h_0_attn_attn_weight, t46, t47, t49, t50, t62, t63, t65, t66, t71, t74, t38, t75, t76, t77, t78, t80, t_transformer_h_0_attn_proj_weight, t86, t95, t96, t101, t100, t103, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t108, t110, t113, t112, t117, t116, t119, t_transformer_h_0_mlp_proj_weight, t125, t134, t135, t140, t139, t142, t_transformer_h_1_attn_attn_weight, t185, t186, t188, t189, t201, t202, t204, t205, t211, t214, t176, t215, t216, t217, t218, t225, t_transformer_h_1_attn_proj_weight, t231, t240, t241, t246, t245, t248, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t253, t255, t258, t257, t262, t261, t264, t_transformer_h_1_mlp_proj_weight, t270, t279, t280, t285, t284, t287, t_transformer_h_2_attn_attn_weight, t330, t331, t333, t334, t346, t347, t349, t350, t356, t359, t321, t360, t361, t362, t363, t370, t_transformer_h_2_attn_proj_weight, t376, t385, t386, t391, t390, t393, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t398, t400, t403, t402, t407, t406, t409, t_transformer_h_2_mlp_proj_weight, t415, t424, t425, t430, t429, t432, t_transformer_h_3_attn_attn_weight, t475, t476, t478, t479, t491, t492, t494, t495, t501, t504, t466, t505, t506, t507, t508, t515, t_transformer_h_3_attn_proj_weight, t521, t530, t531, t536, t535, t538, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t543, t545, t548, t547, t552, t551, t554, t_transformer_h_3_mlp_proj_weight, t560, t569, t570, t575, t574, t577, t_transformer_h_4_attn_attn_weight, t620, t621, t623, t624, t636, t637, t639, t640, t646, t649, t611, t650, t651, t652, t653, t660, t_transformer_h_4_attn_proj_weight, t666, t675, t676, t681, t680, t683, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t688, t690, t693, t692, t697, t696, t699, t_transformer_h_4_mlp_proj_weight, t705, t714, t715, t720, t719, t722, t_transformer_h_5_attn_attn_weight, t765, t766, t768, t769, t781, t782, t784, t785, t791, t794, t756, t795, t796, t797, t798, t805, t_transformer_h_5_attn_proj_weight, t811, t820, t821, t826, t825, t828, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t833, t835, t838, t837, t842, t841, t844, t_transformer_h_5_mlp_proj_weight, t850, t859, t860, t865, t864, t867, t_transformer_h_6_attn_attn_weight, t910, t911, t913, t914, t926, t927, t929, t930, t936, t939, t901, t940, t941, t942, t943, t950, t_transformer_h_6_attn_proj_weight, t956, t965, t966, t971, t970, t973, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t978, t980, t983, t982, t987, t986, t989, t_transformer_h_6_mlp_proj_weight, t995, t1004, t1005, t1010, t1009, t1012, t_transformer_h_7_attn_attn_weight, t1055, t1056, t1058, t1059, t1071, t1072, t1074, t1075, t1081, t1084, t1046, t1085, t1086, t1087, t1088, t1095, t_transformer_h_7_attn_proj_weight, t1101, t1110, t1111, t1116, t1115, t1118, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t1123, t1125, t1128, t1127, t1132, t1131, t1134, t_transformer_h_7_mlp_proj_weight, t1140, t1149, t1150, t1155, t1154, t1157, t_transformer_h_8_attn_attn_weight, t1200, t1201, t1203, t1204, t1216, t1217, t1219, t1220, t1226, t1229, t1191, t1230, t1231, t1232, t1233, t1240, t_transformer_h_8_attn_proj_weight, t1246, t1255, t1256, t1261, t1260, t1263, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t1268, t1270, t1273, t1272, t1277, t1276, t1279, t_transformer_h_8_mlp_proj_weight, t1285, t1294, t1295, t1300, t1299, t1302, t_transformer_h_9_attn_attn_weight, t1345, t1346, t1348, t1349, t1361, t1362, t1364, t1365, t1371, t1374, t1336, t1375, t1376, t1377, t1378, t1385, t_transformer_h_9_attn_proj_weight, t1391, t1400, t1401, t1406, t1405, t1408, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t1413, t1415, t1418, t1417, t1422, t1421, t1424, t_transformer_h_9_mlp_proj_weight, t1430, t1439, t1440, t1445, t1444, t1447, t_transformer_h_10_attn_attn_weight, t1490, t1491, t1493, t1494, t1506, t1507, t1509, t1510, t1516, t1519, t1481, t1520, t1521, t1522, t1523, t1530, t_transformer_h_10_attn_proj_weight, t1536, t1545, t1546, t1551, t1550, t1553, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t1558, t1560, t1563, t1562, t1567, t1566, t1569, t_transformer_h_10_mlp_proj_weight, t1575, t1584, t1585, t1590, t1589, t1592, t_transformer_h_11_attn_attn_weight, t1635, t1636, t1638, t1639, t1651, t1652, t1654, t1655, t1661, t1664, t1626, t1665, t1666, t1667, t1668, t1675, t_transformer_h_11_attn_proj_weight, t1681, t1690, t1691, t1696, t1695, t1698, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t1703, t1705, t1708, t1707, t1712, t1711, t1714, t_transformer_h_11_mlp_proj_weight, t1720, t1729, t1730, t1735, t1734, t1737, t_transformer_h_12_attn_attn_weight, t1780, t1781, t1783, t1784, t1796, t1797, t1799, t1800, t1806, t1809, t1771, t1810, t1811, t1812, t1813, t1820, t_transformer_h_12_attn_proj_weight, t1826, t1835, t1836, t1841, t1840, t1843, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t1848, t1850, t1853, t1852, t1857, t1856, t1859, t_transformer_h_12_mlp_proj_weight, t1865, t1874, t1875, t1880, t1879, t1882, t_transformer_h_13_attn_attn_weight, t1925, t1926, t1928, t1929, t1941, t1942, t1944, t1945, t1951, t1954, t1916, t1955, t1956, t1957, t1958, t1965, t_transformer_h_13_attn_proj_weight, t1971, t1980, t1981, t1986, t1985, t1988, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t1993, t1995, t1998, t1997, t2002, t2001, t2004, t_transformer_h_13_mlp_proj_weight, t2010, t2019, t2020, t2025, t2024, t2027, t_transformer_h_14_attn_attn_weight, t2070, t2071, t2073, t2074, t2086, t2087, t2089, t2090, t2096, t2099, t2061, t2100, t2101, t2102, t2103, t2110, t_transformer_h_14_attn_proj_weight, t2116, t2125, t2126, t2131, t2130, t2133, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t2138, t2140, t2143, t2142, t2147, t2146, t2149, t_transformer_h_14_mlp_proj_weight, t2155, t2164, t2165, t2170, t2169, t2172, t_transformer_h_15_attn_attn_weight, t2215, t2216, t2218, t2219, t2231, t2232, t2234, t2235, t2241, t2244, t2206, t2245, t2246, t2247, t2248, t2255, t_transformer_h_15_attn_proj_weight, t2261, t2270, t2271, t2276, t2275, t2278, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t2283, t2285, t2288, t2287, t2292, t2291, t2294, t_transformer_h_15_mlp_proj_weight, t2300, t2309, t2310, t2315, t2314, t2317, t_lm_head_weight), (32000, False, False, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0)) -============================================ END: before _transform_for_operator_executor_execution -============================================ START: after _transform_for_operator_executor_execution -# Constructed by Transform for operator executor execution (took 52 milliseconds) -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -import torch.nn.functional -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight): - # idx: "cuda:0 i64[1, 512]" - # tos1: "cuda:0 f32[4096, 128]" - # t_lm_head_weight: "cuda:0 bf16[32000, 4096]" - # t_sin: "cuda:0 f32[4096, 128]" - # t_transformer_h_0_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_0_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_0_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_0_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_0_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_0_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_1_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_1_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_1_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_1_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_1_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_2_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_2_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_2_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_2_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_2_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_3_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_3_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_3_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_3_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_3_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_4_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_4_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_4_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_4_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_4_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_5_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_5_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_5_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_5_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_5_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_6_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_6_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_6_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_6_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_6_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_7_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_7_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_7_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_7_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_7_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_8_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_8_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_8_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_8_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_8_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_9_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_9_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_9_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_9_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_9_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_10_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_10_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_10_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_10_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_10_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_11_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_11_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_11_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_11_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_11_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_12_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_12_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_12_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_12_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_12_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_13_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_13_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_13_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_13_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_13_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_14_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_14_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_14_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_14_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_14_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_attn_attn_weight: "cuda:0 bf16[12288, 4096]" - # t_transformer_h_15_attn_proj_weight: "cuda:0 bf16[4096, 4096]" - # t_transformer_h_15_mlp_fc_1_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_fc_2_weight: "cuda:0 bf16[11008, 4096]" - # t_transformer_h_15_mlp_proj_weight: "cuda:0 bf16[4096, 11008]" - # t_transformer_h_15_norm_1_weight: "cuda:0 bf16[4096]" - # t_transformer_h_15_norm_2_weight: "cuda:0 bf16[4096]" - # t_transformer_ln_f_weight: "cuda:0 bf16[4096]" - # t_transformer_wte_weight: "cuda:0 bf16[32000, 4096]" - t0 = prims.slice_prim(tos1, [0, 0], [512, 128], [1, 1]) # t0: "cuda:0 f32[512, 128]" - t1 = prims.slice_prim(t_sin, [0, 0], [512, 128], [1, 1]) # t1: "cuda:0 f32[512, 128]" - t4 = torch.nn.functional.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 512, 4096]" - # t4 = ltorch.embedding(idx, t_transformer_wte_weight, None, None, 2.0, False, False) # t4: "cuda:0 bf16[1, 512, 4096]" - # _ = ltorch.numel(idx) - # t2319 = ltorch.reshape(idx, [512]) # t2319: "cuda:0 i64[512]" - # t2319 = prims.reshape(idx, (512,)) # t2319: "cuda:0 i64[512]" - # t2320 = prims.take(t_transformer_wte_weight, t2319, 0) # t2320: "cuda:0 bf16[512, 4096]" - # t4 = ltorch.reshape(t2320, [1, 512, 4096]) # t4: "cuda:0 bf16[1, 512, 4096]" - # t4 = prims.reshape(t2320, (1, 512, 4096)) # t4: "cuda:0 bf16[1, 512, 4096]" - t5 = prims.convert_element_type(t4, dtypes.float32) # t5: "cuda:0 f32[1, 512, 4096]" - t6 = ltorch.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - # t6 = prims.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]" - t7 = prims.sum(t6, (2,)) # t7: "cuda:0 f32[1, 512]" - t8 = prims.broadcast_in_dim(t7, [1, 512, 1], [0, 1]) # t8: "cuda:0 f32[1, 512, 1]" - t9 = ltorch.true_divide(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - # t9 = prims.div(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]" - t10 = ltorch.add(t9, 1e-05, alpha=None) # t10: "cuda:0 f32[1, 512, 1]" - # t10 = prims.add(t9, 1e-05) # t10: "cuda:0 f32[1, 512, 1]" - t11 = prims.rsqrt(t10) # t11: "cuda:0 f32[1, 512, 1]" - t12 = prims.broadcast_in_dim(t11, (1, 512, 4096), (0, 1, 2)) # t12: "cuda:0 f32[1, 512, 4096]" - t13 = ltorch.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - # t13 = prims.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]" - t14 = prims.convert_element_type(t13, dtypes.bfloat16) # t14: "cuda:0 bf16[1, 512, 4096]" - t15 = prims.broadcast_in_dim(t_transformer_h_0_norm_1_weight, (1, 512, 4096), (2,)) # t15: "cuda:0 bf16[1, 512, 4096]" - t16 = prims.convert_element_type(t14, dtypes.float32) # t16: "cuda:0 f32[1, 512, 4096]" - t17 = prims.convert_element_type(t15, dtypes.float32) # t17: "cuda:0 f32[1, 512, 4096]" - t18 = ltorch.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - # t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]" - t19 = prims.convert_element_type(t18, dtypes.bfloat16) # t19: "cuda:0 bf16[1, 512, 4096]" - t20 = torch.nn.functional.linear(t19, t_transformer_h_0_attn_attn_weight, None) # t20: "cuda:0 bf16[1, 512, 12288]" - # t20 = ltorch.linear(t19, t_transformer_h_0_attn_attn_weight, None) # t20: "cuda:0 bf16[1, 512, 12288]" - # t20 = prims.linear(t19, t_transformer_h_0_attn_attn_weight, None) # t20: "cuda:0 bf16[1, 512, 12288]" - t21 = prims.reshape(t20, (1, 512, 32, 3, 128)) # t21: "cuda:0 bf16[1, 512, 32, 3, 128]" - t22 = prims.transpose(t21, (0, 2, 3, 1, 4)) # t22: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t23, t24, t25) = ltorch.split(t22, (1, 1, 1), 2) - # t23 = prims.slice_prim(t22, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t23: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t24 = prims.slice_prim(t22, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t24: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t25 = prims.slice_prim(t22, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t25: "cuda:0 bf16[1, 32, 1, 512, 128]" - t26 = prims.reshape(t23, (1, 32, 512, 128)) # t26: "cuda:0 bf16[1, 32, 512, 128]" - t32 = prims.reshape(t24, (1, 32, 512, 128)) # t32: "cuda:0 bf16[1, 32, 512, 128]" - t38 = prims.reshape(t25, (1, 32, 512, 128)) # t38: "cuda:0 bf16[1, 32, 512, 128]" - t39 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t39: "cuda:0 bf16[1, 32, 512, 128]" - t40 = prims.slice_prim(t39, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t40: "cuda:0 bf16[1, 32, 512, 64]" - t41 = prims.slice_prim(t39, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t41: "cuda:0 bf16[1, 32, 512, 64]" - t42 = prims.convert_element_type(t41, dtypes.float32) # t42: "cuda:0 f32[1, 32, 512, 64]" - t43 = prims.neg(t42) # t43: "cuda:0 f32[1, 32, 512, 64]" - t44 = prims.convert_element_type(t43, dtypes.bfloat16) # t44: "cuda:0 bf16[1, 32, 512, 64]" - t45 = prims.cat((t44, t40), -1) # t45: "cuda:0 bf16[1, 32, 512, 128]" - t46 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t46: "cuda:0 f32[1, 32, 512, 128]" - t47 = prims.convert_element_type(t39, dtypes.float32) # t47: "cuda:0 f32[1, 32, 512, 128]" - t48 = ltorch.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - # t48 = prims.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]" - t49 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t49: "cuda:0 f32[1, 32, 512, 128]" - t50 = prims.convert_element_type(t45, dtypes.float32) # t50: "cuda:0 f32[1, 32, 512, 128]" - t51 = ltorch.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - # t51 = prims.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]" - t52 = ltorch.add(t48, t51, alpha=None) # t52: "cuda:0 f32[1, 32, 512, 128]" - # t52 = prims.add(t48, t51) # t52: "cuda:0 f32[1, 32, 512, 128]" - t53 = prims.convert_element_type(t52, dtypes.bfloat16) # t53: "cuda:0 bf16[1, 32, 512, 128]" - t54 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t54: "cuda:0 bf16[1, 32, 512, 128]" - t55 = prims.slice_prim(t54, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t55: "cuda:0 bf16[1, 32, 512, 64]" - t56 = prims.slice_prim(t54, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t56: "cuda:0 bf16[1, 32, 512, 64]" - t57 = prims.convert_element_type(t56, dtypes.float32) # t57: "cuda:0 f32[1, 32, 512, 64]" - t58 = prims.neg(t57) # t58: "cuda:0 f32[1, 32, 512, 64]" - t59 = prims.convert_element_type(t58, dtypes.bfloat16) # t59: "cuda:0 bf16[1, 32, 512, 64]" - t61 = prims.cat((t59, t55), -1) # t61: "cuda:0 bf16[1, 32, 512, 128]" - t62 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t62: "cuda:0 f32[1, 32, 512, 128]" - t63 = prims.convert_element_type(t54, dtypes.float32) # t63: "cuda:0 f32[1, 32, 512, 128]" - t64 = ltorch.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - # t64 = prims.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]" - t65 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t65: "cuda:0 f32[1, 32, 512, 128]" - t66 = prims.convert_element_type(t61, dtypes.float32) # t66: "cuda:0 f32[1, 32, 512, 128]" - t67 = ltorch.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - # t67 = prims.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]" - t68 = ltorch.add(t64, t67, alpha=None) # t68: "cuda:0 f32[1, 32, 512, 128]" - # t68 = prims.add(t64, t67) # t68: "cuda:0 f32[1, 32, 512, 128]" - t69 = prims.convert_element_type(t68, dtypes.bfloat16) # t69: "cuda:0 bf16[1, 32, 512, 128]" - t70 = prims.slice_prim(t26, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t70: "cuda:0 bf16[1, 32, 512, 0]" - t71 = prims.cat((t53, t70), -1) # t71: "cuda:0 bf16[1, 32, 512, 128]" - t72 = prims.slice_prim(t32, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t72: "cuda:0 bf16[1, 32, 512, 0]" - t74 = prims.cat((t69, t72), -1) # t74: "cuda:0 bf16[1, 32, 512, 128]" - (t75, t76, t77, t78) = cudnn_sdpa_fwd(t71, t74, t38, None, 0.0, True, scale=0.08838834764831843) - t79 = prims.transpose(t75, (0, 2, 1, 3)) # t79: "cuda:0 bf16[1, 512, 32, 128]" - t80 = prims.reshape(t79, (1, 512, 4096)) # t80: "cuda:0 bf16[1, 512, 4096]" - t81 = torch.nn.functional.linear(t80, t_transformer_h_0_attn_proj_weight, None) # t81: "cuda:0 bf16[1, 512, 4096]" - # t81 = ltorch.linear(t80, t_transformer_h_0_attn_proj_weight, None) # t81: "cuda:0 bf16[1, 512, 4096]" - # t81 = prims.linear(t80, t_transformer_h_0_attn_proj_weight, None) # t81: "cuda:0 bf16[1, 512, 4096]" - t82 = prims.convert_element_type(t81, dtypes.float32) # t82: "cuda:0 f32[1, 512, 4096]" - t83 = prims.convert_element_type(t4, dtypes.float32) # t83: "cuda:0 f32[1, 512, 4096]" - t84 = ltorch.add(t82, t83, alpha=None) # t84: "cuda:0 f32[1, 512, 4096]" - # t84 = prims.add(t82, t83) # t84: "cuda:0 f32[1, 512, 4096]" - t85 = prims.convert_element_type(t84, dtypes.bfloat16) # t85: "cuda:0 bf16[1, 512, 4096]" - t86 = prims.convert_element_type(t85, dtypes.float32) # t86: "cuda:0 f32[1, 512, 4096]" - t87 = ltorch.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - # t87 = prims.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]" - t89 = prims.sum(t87, (2,)) # t89: "cuda:0 f32[1, 512]" - t90 = prims.broadcast_in_dim(t89, [1, 512, 1], [0, 1]) # t90: "cuda:0 f32[1, 512, 1]" - t92 = ltorch.true_divide(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - # t92 = prims.div(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]" - t94 = ltorch.add(t92, 1e-05, alpha=None) # t94: "cuda:0 f32[1, 512, 1]" - # t94 = prims.add(t92, 1e-05) # t94: "cuda:0 f32[1, 512, 1]" - t95 = prims.rsqrt(t94) # t95: "cuda:0 f32[1, 512, 1]" - t96 = prims.broadcast_in_dim(t95, (1, 512, 4096), (0, 1, 2)) # t96: "cuda:0 f32[1, 512, 4096]" - t97 = ltorch.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - # t97 = prims.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]" - t98 = prims.convert_element_type(t97, dtypes.bfloat16) # t98: "cuda:0 bf16[1, 512, 4096]" - t99 = prims.broadcast_in_dim(t_transformer_h_0_norm_2_weight, (1, 512, 4096), (2,)) # t99: "cuda:0 bf16[1, 512, 4096]" - t100 = prims.convert_element_type(t98, dtypes.float32) # t100: "cuda:0 f32[1, 512, 4096]" - t101 = prims.convert_element_type(t99, dtypes.float32) # t101: "cuda:0 f32[1, 512, 4096]" - t102 = ltorch.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - # t102 = prims.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]" - t103 = prims.convert_element_type(t102, dtypes.bfloat16) # t103: "cuda:0 bf16[1, 512, 4096]" - t104 = torch.nn.functional.linear(t103, t_transformer_h_0_mlp_fc_1_weight, None) # t104: "cuda:0 bf16[1, 512, 11008]" - # t104 = ltorch.linear(t103, t_transformer_h_0_mlp_fc_1_weight, None) # t104: "cuda:0 bf16[1, 512, 11008]" - # t104 = prims.linear(t103, t_transformer_h_0_mlp_fc_1_weight, None) # t104: "cuda:0 bf16[1, 512, 11008]" - t105 = torch.nn.functional.linear(t103, t_transformer_h_0_mlp_fc_2_weight, None) # t105: "cuda:0 bf16[1, 512, 11008]" - # t105 = ltorch.linear(t103, t_transformer_h_0_mlp_fc_2_weight, None) # t105: "cuda:0 bf16[1, 512, 11008]" - # t105 = prims.linear(t103, t_transformer_h_0_mlp_fc_2_weight, None) # t105: "cuda:0 bf16[1, 512, 11008]" - t106 = prims.convert_element_type(t104, dtypes.float32) # t106: "cuda:0 f32[1, 512, 11008]" - t107 = prims.neg(t106) # t107: "cuda:0 f32[1, 512, 11008]" - t108 = prims.exp(t107) # t108: "cuda:0 f32[1, 512, 11008]" - t109 = ltorch.add(1.0, t108, alpha=None) # t109: "cuda:0 f32[1, 512, 11008]" - # t109 = prims.add(1.0, t108) # t109: "cuda:0 f32[1, 512, 11008]" - t110 = prims.reciprocal(t109) # t110: "cuda:0 f32[1, 512, 11008]" - t111 = prims.convert_element_type(t110, dtypes.bfloat16) # t111: "cuda:0 bf16[1, 512, 11008]" - t112 = prims.convert_element_type(t104, dtypes.float32) # t112: "cuda:0 f32[1, 512, 11008]" - t113 = prims.convert_element_type(t111, dtypes.float32) # t113: "cuda:0 f32[1, 512, 11008]" - t114 = ltorch.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - # t114 = prims.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]" - t115 = prims.convert_element_type(t114, dtypes.bfloat16) # t115: "cuda:0 bf16[1, 512, 11008]" - t116 = prims.convert_element_type(t115, dtypes.float32) # t116: "cuda:0 f32[1, 512, 11008]" - t117 = prims.convert_element_type(t105, dtypes.float32) # t117: "cuda:0 f32[1, 512, 11008]" - t118 = ltorch.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - # t118 = prims.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]" - t119 = prims.convert_element_type(t118, dtypes.bfloat16) # t119: "cuda:0 bf16[1, 512, 11008]" - t120 = torch.nn.functional.linear(t119, t_transformer_h_0_mlp_proj_weight, None) # t120: "cuda:0 bf16[1, 512, 4096]" - # t120 = ltorch.linear(t119, t_transformer_h_0_mlp_proj_weight, None) # t120: "cuda:0 bf16[1, 512, 4096]" - # t120 = prims.linear(t119, t_transformer_h_0_mlp_proj_weight, None) # t120: "cuda:0 bf16[1, 512, 4096]" - t121 = prims.convert_element_type(t120, dtypes.float32) # t121: "cuda:0 f32[1, 512, 4096]" - t122 = prims.convert_element_type(t85, dtypes.float32) # t122: "cuda:0 f32[1, 512, 4096]" - t123 = ltorch.add(t121, t122, alpha=None) # t123: "cuda:0 f32[1, 512, 4096]" - # t123 = prims.add(t121, t122) # t123: "cuda:0 f32[1, 512, 4096]" - t124 = prims.convert_element_type(t123, dtypes.bfloat16) # t124: "cuda:0 bf16[1, 512, 4096]" - t125 = prims.convert_element_type(t124, dtypes.float32) # t125: "cuda:0 f32[1, 512, 4096]" - t126 = ltorch.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - # t126 = prims.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]" - t128 = prims.sum(t126, (2,)) # t128: "cuda:0 f32[1, 512]" - t129 = prims.broadcast_in_dim(t128, [1, 512, 1], [0, 1]) # t129: "cuda:0 f32[1, 512, 1]" - t131 = ltorch.true_divide(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - # t131 = prims.div(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]" - t133 = ltorch.add(t131, 1e-05, alpha=None) # t133: "cuda:0 f32[1, 512, 1]" - # t133 = prims.add(t131, 1e-05) # t133: "cuda:0 f32[1, 512, 1]" - t134 = prims.rsqrt(t133) # t134: "cuda:0 f32[1, 512, 1]" - t135 = prims.broadcast_in_dim(t134, (1, 512, 4096), (0, 1, 2)) # t135: "cuda:0 f32[1, 512, 4096]" - t136 = ltorch.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - # t136 = prims.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]" - t137 = prims.convert_element_type(t136, dtypes.bfloat16) # t137: "cuda:0 bf16[1, 512, 4096]" - t138 = prims.broadcast_in_dim(t_transformer_h_1_norm_1_weight, (1, 512, 4096), (2,)) # t138: "cuda:0 bf16[1, 512, 4096]" - t139 = prims.convert_element_type(t137, dtypes.float32) # t139: "cuda:0 f32[1, 512, 4096]" - t140 = prims.convert_element_type(t138, dtypes.float32) # t140: "cuda:0 f32[1, 512, 4096]" - t141 = ltorch.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - # t141 = prims.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]" - t142 = prims.convert_element_type(t141, dtypes.bfloat16) # t142: "cuda:0 bf16[1, 512, 4096]" - t143 = torch.nn.functional.linear(t142, t_transformer_h_1_attn_attn_weight, None) # t143: "cuda:0 bf16[1, 512, 12288]" - # t143 = ltorch.linear(t142, t_transformer_h_1_attn_attn_weight, None) # t143: "cuda:0 bf16[1, 512, 12288]" - # t143 = prims.linear(t142, t_transformer_h_1_attn_attn_weight, None) # t143: "cuda:0 bf16[1, 512, 12288]" - t149 = prims.reshape(t143, (1, 512, 32, 3, 128)) # t149: "cuda:0 bf16[1, 512, 32, 3, 128]" - t155 = prims.transpose(t149, (0, 2, 3, 1, 4)) # t155: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t156, t157, t158) = ltorch.split(t155, (1, 1, 1), 2) - # t156 = prims.slice_prim(t155, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t156: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t157 = prims.slice_prim(t155, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t157: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t158 = prims.slice_prim(t155, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t158: "cuda:0 bf16[1, 32, 1, 512, 128]" - t164 = prims.reshape(t156, (1, 32, 512, 128)) # t164: "cuda:0 bf16[1, 32, 512, 128]" - t170 = prims.reshape(t157, (1, 32, 512, 128)) # t170: "cuda:0 bf16[1, 32, 512, 128]" - t176 = prims.reshape(t158, (1, 32, 512, 128)) # t176: "cuda:0 bf16[1, 32, 512, 128]" - t177 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t177: "cuda:0 bf16[1, 32, 512, 128]" - t178 = prims.slice_prim(t177, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t178: "cuda:0 bf16[1, 32, 512, 64]" - t179 = prims.slice_prim(t177, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t179: "cuda:0 bf16[1, 32, 512, 64]" - t180 = prims.convert_element_type(t179, dtypes.float32) # t180: "cuda:0 f32[1, 32, 512, 64]" - t181 = prims.neg(t180) # t181: "cuda:0 f32[1, 32, 512, 64]" - t182 = prims.convert_element_type(t181, dtypes.bfloat16) # t182: "cuda:0 bf16[1, 32, 512, 64]" - t184 = prims.cat((t182, t178), -1) # t184: "cuda:0 bf16[1, 32, 512, 128]" - t185 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t185: "cuda:0 f32[1, 32, 512, 128]" - t186 = prims.convert_element_type(t177, dtypes.float32) # t186: "cuda:0 f32[1, 32, 512, 128]" - t187 = ltorch.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - # t187 = prims.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]" - t188 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t188: "cuda:0 f32[1, 32, 512, 128]" - t189 = prims.convert_element_type(t184, dtypes.float32) # t189: "cuda:0 f32[1, 32, 512, 128]" - t190 = ltorch.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - # t190 = prims.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]" - t191 = ltorch.add(t187, t190, alpha=None) # t191: "cuda:0 f32[1, 32, 512, 128]" - # t191 = prims.add(t187, t190) # t191: "cuda:0 f32[1, 32, 512, 128]" - t192 = prims.convert_element_type(t191, dtypes.bfloat16) # t192: "cuda:0 bf16[1, 32, 512, 128]" - t193 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t193: "cuda:0 bf16[1, 32, 512, 128]" - t194 = prims.slice_prim(t193, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t194: "cuda:0 bf16[1, 32, 512, 64]" - t195 = prims.slice_prim(t193, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t195: "cuda:0 bf16[1, 32, 512, 64]" - t196 = prims.convert_element_type(t195, dtypes.float32) # t196: "cuda:0 f32[1, 32, 512, 64]" - t197 = prims.neg(t196) # t197: "cuda:0 f32[1, 32, 512, 64]" - t198 = prims.convert_element_type(t197, dtypes.bfloat16) # t198: "cuda:0 bf16[1, 32, 512, 64]" - t200 = prims.cat((t198, t194), -1) # t200: "cuda:0 bf16[1, 32, 512, 128]" - t201 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t201: "cuda:0 f32[1, 32, 512, 128]" - t202 = prims.convert_element_type(t193, dtypes.float32) # t202: "cuda:0 f32[1, 32, 512, 128]" - t203 = ltorch.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - # t203 = prims.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]" - t204 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t204: "cuda:0 f32[1, 32, 512, 128]" - t205 = prims.convert_element_type(t200, dtypes.float32) # t205: "cuda:0 f32[1, 32, 512, 128]" - t206 = ltorch.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - # t206 = prims.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]" - t207 = ltorch.add(t203, t206, alpha=None) # t207: "cuda:0 f32[1, 32, 512, 128]" - # t207 = prims.add(t203, t206) # t207: "cuda:0 f32[1, 32, 512, 128]" - t208 = prims.convert_element_type(t207, dtypes.bfloat16) # t208: "cuda:0 bf16[1, 32, 512, 128]" - t209 = prims.slice_prim(t164, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t209: "cuda:0 bf16[1, 32, 512, 0]" - t211 = prims.cat((t192, t209), -1) # t211: "cuda:0 bf16[1, 32, 512, 128]" - t212 = prims.slice_prim(t170, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t212: "cuda:0 bf16[1, 32, 512, 0]" - t214 = prims.cat((t208, t212), -1) # t214: "cuda:0 bf16[1, 32, 512, 128]" - (t215, t216, t217, t218) = cudnn_sdpa_fwd(t211, t214, t176, None, 0.0, True, scale=0.08838834764831843) - t221 = prims.transpose(t215, (0, 2, 1, 3)) # t221: "cuda:0 bf16[1, 512, 32, 128]" - t225 = prims.reshape(t221, (1, 512, 4096)) # t225: "cuda:0 bf16[1, 512, 4096]" - t226 = torch.nn.functional.linear(t225, t_transformer_h_1_attn_proj_weight, None) # t226: "cuda:0 bf16[1, 512, 4096]" - # t226 = ltorch.linear(t225, t_transformer_h_1_attn_proj_weight, None) # t226: "cuda:0 bf16[1, 512, 4096]" - # t226 = prims.linear(t225, t_transformer_h_1_attn_proj_weight, None) # t226: "cuda:0 bf16[1, 512, 4096]" - t227 = prims.convert_element_type(t226, dtypes.float32) # t227: "cuda:0 f32[1, 512, 4096]" - t228 = prims.convert_element_type(t124, dtypes.float32) # t228: "cuda:0 f32[1, 512, 4096]" - t229 = ltorch.add(t227, t228, alpha=None) # t229: "cuda:0 f32[1, 512, 4096]" - # t229 = prims.add(t227, t228) # t229: "cuda:0 f32[1, 512, 4096]" - t230 = prims.convert_element_type(t229, dtypes.bfloat16) # t230: "cuda:0 bf16[1, 512, 4096]" - t231 = prims.convert_element_type(t230, dtypes.float32) # t231: "cuda:0 f32[1, 512, 4096]" - t232 = ltorch.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - # t232 = prims.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]" - t234 = prims.sum(t232, (2,)) # t234: "cuda:0 f32[1, 512]" - t235 = prims.broadcast_in_dim(t234, [1, 512, 1], [0, 1]) # t235: "cuda:0 f32[1, 512, 1]" - t237 = ltorch.true_divide(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - # t237 = prims.div(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]" - t239 = ltorch.add(t237, 1e-05, alpha=None) # t239: "cuda:0 f32[1, 512, 1]" - # t239 = prims.add(t237, 1e-05) # t239: "cuda:0 f32[1, 512, 1]" - t240 = prims.rsqrt(t239) # t240: "cuda:0 f32[1, 512, 1]" - t241 = prims.broadcast_in_dim(t240, (1, 512, 4096), (0, 1, 2)) # t241: "cuda:0 f32[1, 512, 4096]" - t242 = ltorch.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - # t242 = prims.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]" - t243 = prims.convert_element_type(t242, dtypes.bfloat16) # t243: "cuda:0 bf16[1, 512, 4096]" - t244 = prims.broadcast_in_dim(t_transformer_h_1_norm_2_weight, (1, 512, 4096), (2,)) # t244: "cuda:0 bf16[1, 512, 4096]" - t245 = prims.convert_element_type(t243, dtypes.float32) # t245: "cuda:0 f32[1, 512, 4096]" - t246 = prims.convert_element_type(t244, dtypes.float32) # t246: "cuda:0 f32[1, 512, 4096]" - t247 = ltorch.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - # t247 = prims.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]" - t248 = prims.convert_element_type(t247, dtypes.bfloat16) # t248: "cuda:0 bf16[1, 512, 4096]" - t249 = torch.nn.functional.linear(t248, t_transformer_h_1_mlp_fc_1_weight, None) # t249: "cuda:0 bf16[1, 512, 11008]" - # t249 = ltorch.linear(t248, t_transformer_h_1_mlp_fc_1_weight, None) # t249: "cuda:0 bf16[1, 512, 11008]" - # t249 = prims.linear(t248, t_transformer_h_1_mlp_fc_1_weight, None) # t249: "cuda:0 bf16[1, 512, 11008]" - t250 = torch.nn.functional.linear(t248, t_transformer_h_1_mlp_fc_2_weight, None) # t250: "cuda:0 bf16[1, 512, 11008]" - # t250 = ltorch.linear(t248, t_transformer_h_1_mlp_fc_2_weight, None) # t250: "cuda:0 bf16[1, 512, 11008]" - # t250 = prims.linear(t248, t_transformer_h_1_mlp_fc_2_weight, None) # t250: "cuda:0 bf16[1, 512, 11008]" - t251 = prims.convert_element_type(t249, dtypes.float32) # t251: "cuda:0 f32[1, 512, 11008]" - t252 = prims.neg(t251) # t252: "cuda:0 f32[1, 512, 11008]" - t253 = prims.exp(t252) # t253: "cuda:0 f32[1, 512, 11008]" - t254 = ltorch.add(1.0, t253, alpha=None) # t254: "cuda:0 f32[1, 512, 11008]" - # t254 = prims.add(1.0, t253) # t254: "cuda:0 f32[1, 512, 11008]" - t255 = prims.reciprocal(t254) # t255: "cuda:0 f32[1, 512, 11008]" - t256 = prims.convert_element_type(t255, dtypes.bfloat16) # t256: "cuda:0 bf16[1, 512, 11008]" - t257 = prims.convert_element_type(t249, dtypes.float32) # t257: "cuda:0 f32[1, 512, 11008]" - t258 = prims.convert_element_type(t256, dtypes.float32) # t258: "cuda:0 f32[1, 512, 11008]" - t259 = ltorch.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - # t259 = prims.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]" - t260 = prims.convert_element_type(t259, dtypes.bfloat16) # t260: "cuda:0 bf16[1, 512, 11008]" - t261 = prims.convert_element_type(t260, dtypes.float32) # t261: "cuda:0 f32[1, 512, 11008]" - t262 = prims.convert_element_type(t250, dtypes.float32) # t262: "cuda:0 f32[1, 512, 11008]" - t263 = ltorch.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - # t263 = prims.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]" - t264 = prims.convert_element_type(t263, dtypes.bfloat16) # t264: "cuda:0 bf16[1, 512, 11008]" - t265 = torch.nn.functional.linear(t264, t_transformer_h_1_mlp_proj_weight, None) # t265: "cuda:0 bf16[1, 512, 4096]" - # t265 = ltorch.linear(t264, t_transformer_h_1_mlp_proj_weight, None) # t265: "cuda:0 bf16[1, 512, 4096]" - # t265 = prims.linear(t264, t_transformer_h_1_mlp_proj_weight, None) # t265: "cuda:0 bf16[1, 512, 4096]" - t266 = prims.convert_element_type(t265, dtypes.float32) # t266: "cuda:0 f32[1, 512, 4096]" - t267 = prims.convert_element_type(t230, dtypes.float32) # t267: "cuda:0 f32[1, 512, 4096]" - t268 = ltorch.add(t266, t267, alpha=None) # t268: "cuda:0 f32[1, 512, 4096]" - # t268 = prims.add(t266, t267) # t268: "cuda:0 f32[1, 512, 4096]" - t269 = prims.convert_element_type(t268, dtypes.bfloat16) # t269: "cuda:0 bf16[1, 512, 4096]" - t270 = prims.convert_element_type(t269, dtypes.float32) # t270: "cuda:0 f32[1, 512, 4096]" - t271 = ltorch.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - # t271 = prims.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]" - t273 = prims.sum(t271, (2,)) # t273: "cuda:0 f32[1, 512]" - t274 = prims.broadcast_in_dim(t273, [1, 512, 1], [0, 1]) # t274: "cuda:0 f32[1, 512, 1]" - t276 = ltorch.true_divide(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - # t276 = prims.div(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]" - t278 = ltorch.add(t276, 1e-05, alpha=None) # t278: "cuda:0 f32[1, 512, 1]" - # t278 = prims.add(t276, 1e-05) # t278: "cuda:0 f32[1, 512, 1]" - t279 = prims.rsqrt(t278) # t279: "cuda:0 f32[1, 512, 1]" - t280 = prims.broadcast_in_dim(t279, (1, 512, 4096), (0, 1, 2)) # t280: "cuda:0 f32[1, 512, 4096]" - t281 = ltorch.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - # t281 = prims.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]" - t282 = prims.convert_element_type(t281, dtypes.bfloat16) # t282: "cuda:0 bf16[1, 512, 4096]" - t283 = prims.broadcast_in_dim(t_transformer_h_2_norm_1_weight, (1, 512, 4096), (2,)) # t283: "cuda:0 bf16[1, 512, 4096]" - t284 = prims.convert_element_type(t282, dtypes.float32) # t284: "cuda:0 f32[1, 512, 4096]" - t285 = prims.convert_element_type(t283, dtypes.float32) # t285: "cuda:0 f32[1, 512, 4096]" - t286 = ltorch.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - # t286 = prims.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]" - t287 = prims.convert_element_type(t286, dtypes.bfloat16) # t287: "cuda:0 bf16[1, 512, 4096]" - t288 = torch.nn.functional.linear(t287, t_transformer_h_2_attn_attn_weight, None) # t288: "cuda:0 bf16[1, 512, 12288]" - # t288 = ltorch.linear(t287, t_transformer_h_2_attn_attn_weight, None) # t288: "cuda:0 bf16[1, 512, 12288]" - # t288 = prims.linear(t287, t_transformer_h_2_attn_attn_weight, None) # t288: "cuda:0 bf16[1, 512, 12288]" - t294 = prims.reshape(t288, (1, 512, 32, 3, 128)) # t294: "cuda:0 bf16[1, 512, 32, 3, 128]" - t300 = prims.transpose(t294, (0, 2, 3, 1, 4)) # t300: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t301, t302, t303) = ltorch.split(t300, (1, 1, 1), 2) - # t301 = prims.slice_prim(t300, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t301: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t302 = prims.slice_prim(t300, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t302: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t303 = prims.slice_prim(t300, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t303: "cuda:0 bf16[1, 32, 1, 512, 128]" - t309 = prims.reshape(t301, (1, 32, 512, 128)) # t309: "cuda:0 bf16[1, 32, 512, 128]" - t315 = prims.reshape(t302, (1, 32, 512, 128)) # t315: "cuda:0 bf16[1, 32, 512, 128]" - t321 = prims.reshape(t303, (1, 32, 512, 128)) # t321: "cuda:0 bf16[1, 32, 512, 128]" - t322 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t322: "cuda:0 bf16[1, 32, 512, 128]" - t323 = prims.slice_prim(t322, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t323: "cuda:0 bf16[1, 32, 512, 64]" - t324 = prims.slice_prim(t322, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t324: "cuda:0 bf16[1, 32, 512, 64]" - t325 = prims.convert_element_type(t324, dtypes.float32) # t325: "cuda:0 f32[1, 32, 512, 64]" - t326 = prims.neg(t325) # t326: "cuda:0 f32[1, 32, 512, 64]" - t327 = prims.convert_element_type(t326, dtypes.bfloat16) # t327: "cuda:0 bf16[1, 32, 512, 64]" - t329 = prims.cat((t327, t323), -1) # t329: "cuda:0 bf16[1, 32, 512, 128]" - t330 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t330: "cuda:0 f32[1, 32, 512, 128]" - t331 = prims.convert_element_type(t322, dtypes.float32) # t331: "cuda:0 f32[1, 32, 512, 128]" - t332 = ltorch.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - # t332 = prims.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]" - t333 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t333: "cuda:0 f32[1, 32, 512, 128]" - t334 = prims.convert_element_type(t329, dtypes.float32) # t334: "cuda:0 f32[1, 32, 512, 128]" - t335 = ltorch.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - # t335 = prims.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]" - t336 = ltorch.add(t332, t335, alpha=None) # t336: "cuda:0 f32[1, 32, 512, 128]" - # t336 = prims.add(t332, t335) # t336: "cuda:0 f32[1, 32, 512, 128]" - t337 = prims.convert_element_type(t336, dtypes.bfloat16) # t337: "cuda:0 bf16[1, 32, 512, 128]" - t338 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t338: "cuda:0 bf16[1, 32, 512, 128]" - t339 = prims.slice_prim(t338, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t339: "cuda:0 bf16[1, 32, 512, 64]" - t340 = prims.slice_prim(t338, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t340: "cuda:0 bf16[1, 32, 512, 64]" - t341 = prims.convert_element_type(t340, dtypes.float32) # t341: "cuda:0 f32[1, 32, 512, 64]" - t342 = prims.neg(t341) # t342: "cuda:0 f32[1, 32, 512, 64]" - t343 = prims.convert_element_type(t342, dtypes.bfloat16) # t343: "cuda:0 bf16[1, 32, 512, 64]" - t345 = prims.cat((t343, t339), -1) # t345: "cuda:0 bf16[1, 32, 512, 128]" - t346 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t346: "cuda:0 f32[1, 32, 512, 128]" - t347 = prims.convert_element_type(t338, dtypes.float32) # t347: "cuda:0 f32[1, 32, 512, 128]" - t348 = ltorch.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - # t348 = prims.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]" - t349 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t349: "cuda:0 f32[1, 32, 512, 128]" - t350 = prims.convert_element_type(t345, dtypes.float32) # t350: "cuda:0 f32[1, 32, 512, 128]" - t351 = ltorch.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - # t351 = prims.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]" - t352 = ltorch.add(t348, t351, alpha=None) # t352: "cuda:0 f32[1, 32, 512, 128]" - # t352 = prims.add(t348, t351) # t352: "cuda:0 f32[1, 32, 512, 128]" - t353 = prims.convert_element_type(t352, dtypes.bfloat16) # t353: "cuda:0 bf16[1, 32, 512, 128]" - t354 = prims.slice_prim(t309, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t354: "cuda:0 bf16[1, 32, 512, 0]" - t356 = prims.cat((t337, t354), -1) # t356: "cuda:0 bf16[1, 32, 512, 128]" - t357 = prims.slice_prim(t315, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t357: "cuda:0 bf16[1, 32, 512, 0]" - t359 = prims.cat((t353, t357), -1) # t359: "cuda:0 bf16[1, 32, 512, 128]" - (t360, t361, t362, t363) = cudnn_sdpa_fwd(t356, t359, t321, None, 0.0, True, scale=0.08838834764831843) - t366 = prims.transpose(t360, (0, 2, 1, 3)) # t366: "cuda:0 bf16[1, 512, 32, 128]" - t370 = prims.reshape(t366, (1, 512, 4096)) # t370: "cuda:0 bf16[1, 512, 4096]" - t371 = torch.nn.functional.linear(t370, t_transformer_h_2_attn_proj_weight, None) # t371: "cuda:0 bf16[1, 512, 4096]" - # t371 = ltorch.linear(t370, t_transformer_h_2_attn_proj_weight, None) # t371: "cuda:0 bf16[1, 512, 4096]" - # t371 = prims.linear(t370, t_transformer_h_2_attn_proj_weight, None) # t371: "cuda:0 bf16[1, 512, 4096]" - t372 = prims.convert_element_type(t371, dtypes.float32) # t372: "cuda:0 f32[1, 512, 4096]" - t373 = prims.convert_element_type(t269, dtypes.float32) # t373: "cuda:0 f32[1, 512, 4096]" - t374 = ltorch.add(t372, t373, alpha=None) # t374: "cuda:0 f32[1, 512, 4096]" - # t374 = prims.add(t372, t373) # t374: "cuda:0 f32[1, 512, 4096]" - t375 = prims.convert_element_type(t374, dtypes.bfloat16) # t375: "cuda:0 bf16[1, 512, 4096]" - t376 = prims.convert_element_type(t375, dtypes.float32) # t376: "cuda:0 f32[1, 512, 4096]" - t377 = ltorch.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - # t377 = prims.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]" - t379 = prims.sum(t377, (2,)) # t379: "cuda:0 f32[1, 512]" - t380 = prims.broadcast_in_dim(t379, [1, 512, 1], [0, 1]) # t380: "cuda:0 f32[1, 512, 1]" - t382 = ltorch.true_divide(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - # t382 = prims.div(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]" - t384 = ltorch.add(t382, 1e-05, alpha=None) # t384: "cuda:0 f32[1, 512, 1]" - # t384 = prims.add(t382, 1e-05) # t384: "cuda:0 f32[1, 512, 1]" - t385 = prims.rsqrt(t384) # t385: "cuda:0 f32[1, 512, 1]" - t386 = prims.broadcast_in_dim(t385, (1, 512, 4096), (0, 1, 2)) # t386: "cuda:0 f32[1, 512, 4096]" - t387 = ltorch.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - # t387 = prims.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]" - t388 = prims.convert_element_type(t387, dtypes.bfloat16) # t388: "cuda:0 bf16[1, 512, 4096]" - t389 = prims.broadcast_in_dim(t_transformer_h_2_norm_2_weight, (1, 512, 4096), (2,)) # t389: "cuda:0 bf16[1, 512, 4096]" - t390 = prims.convert_element_type(t388, dtypes.float32) # t390: "cuda:0 f32[1, 512, 4096]" - t391 = prims.convert_element_type(t389, dtypes.float32) # t391: "cuda:0 f32[1, 512, 4096]" - t392 = ltorch.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - # t392 = prims.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]" - t393 = prims.convert_element_type(t392, dtypes.bfloat16) # t393: "cuda:0 bf16[1, 512, 4096]" - t394 = torch.nn.functional.linear(t393, t_transformer_h_2_mlp_fc_1_weight, None) # t394: "cuda:0 bf16[1, 512, 11008]" - # t394 = ltorch.linear(t393, t_transformer_h_2_mlp_fc_1_weight, None) # t394: "cuda:0 bf16[1, 512, 11008]" - # t394 = prims.linear(t393, t_transformer_h_2_mlp_fc_1_weight, None) # t394: "cuda:0 bf16[1, 512, 11008]" - t395 = torch.nn.functional.linear(t393, t_transformer_h_2_mlp_fc_2_weight, None) # t395: "cuda:0 bf16[1, 512, 11008]" - # t395 = ltorch.linear(t393, t_transformer_h_2_mlp_fc_2_weight, None) # t395: "cuda:0 bf16[1, 512, 11008]" - # t395 = prims.linear(t393, t_transformer_h_2_mlp_fc_2_weight, None) # t395: "cuda:0 bf16[1, 512, 11008]" - t396 = prims.convert_element_type(t394, dtypes.float32) # t396: "cuda:0 f32[1, 512, 11008]" - t397 = prims.neg(t396) # t397: "cuda:0 f32[1, 512, 11008]" - t398 = prims.exp(t397) # t398: "cuda:0 f32[1, 512, 11008]" - t399 = ltorch.add(1.0, t398, alpha=None) # t399: "cuda:0 f32[1, 512, 11008]" - # t399 = prims.add(1.0, t398) # t399: "cuda:0 f32[1, 512, 11008]" - t400 = prims.reciprocal(t399) # t400: "cuda:0 f32[1, 512, 11008]" - t401 = prims.convert_element_type(t400, dtypes.bfloat16) # t401: "cuda:0 bf16[1, 512, 11008]" - t402 = prims.convert_element_type(t394, dtypes.float32) # t402: "cuda:0 f32[1, 512, 11008]" - t403 = prims.convert_element_type(t401, dtypes.float32) # t403: "cuda:0 f32[1, 512, 11008]" - t404 = ltorch.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - # t404 = prims.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]" - t405 = prims.convert_element_type(t404, dtypes.bfloat16) # t405: "cuda:0 bf16[1, 512, 11008]" - t406 = prims.convert_element_type(t405, dtypes.float32) # t406: "cuda:0 f32[1, 512, 11008]" - t407 = prims.convert_element_type(t395, dtypes.float32) # t407: "cuda:0 f32[1, 512, 11008]" - t408 = ltorch.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - # t408 = prims.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]" - t409 = prims.convert_element_type(t408, dtypes.bfloat16) # t409: "cuda:0 bf16[1, 512, 11008]" - t410 = torch.nn.functional.linear(t409, t_transformer_h_2_mlp_proj_weight, None) # t410: "cuda:0 bf16[1, 512, 4096]" - # t410 = ltorch.linear(t409, t_transformer_h_2_mlp_proj_weight, None) # t410: "cuda:0 bf16[1, 512, 4096]" - # t410 = prims.linear(t409, t_transformer_h_2_mlp_proj_weight, None) # t410: "cuda:0 bf16[1, 512, 4096]" - t411 = prims.convert_element_type(t410, dtypes.float32) # t411: "cuda:0 f32[1, 512, 4096]" - t412 = prims.convert_element_type(t375, dtypes.float32) # t412: "cuda:0 f32[1, 512, 4096]" - t413 = ltorch.add(t411, t412, alpha=None) # t413: "cuda:0 f32[1, 512, 4096]" - # t413 = prims.add(t411, t412) # t413: "cuda:0 f32[1, 512, 4096]" - t414 = prims.convert_element_type(t413, dtypes.bfloat16) # t414: "cuda:0 bf16[1, 512, 4096]" - t415 = prims.convert_element_type(t414, dtypes.float32) # t415: "cuda:0 f32[1, 512, 4096]" - t416 = ltorch.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - # t416 = prims.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]" - t418 = prims.sum(t416, (2,)) # t418: "cuda:0 f32[1, 512]" - t419 = prims.broadcast_in_dim(t418, [1, 512, 1], [0, 1]) # t419: "cuda:0 f32[1, 512, 1]" - t421 = ltorch.true_divide(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - # t421 = prims.div(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]" - t423 = ltorch.add(t421, 1e-05, alpha=None) # t423: "cuda:0 f32[1, 512, 1]" - # t423 = prims.add(t421, 1e-05) # t423: "cuda:0 f32[1, 512, 1]" - t424 = prims.rsqrt(t423) # t424: "cuda:0 f32[1, 512, 1]" - t425 = prims.broadcast_in_dim(t424, (1, 512, 4096), (0, 1, 2)) # t425: "cuda:0 f32[1, 512, 4096]" - t426 = ltorch.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - # t426 = prims.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]" - t427 = prims.convert_element_type(t426, dtypes.bfloat16) # t427: "cuda:0 bf16[1, 512, 4096]" - t428 = prims.broadcast_in_dim(t_transformer_h_3_norm_1_weight, (1, 512, 4096), (2,)) # t428: "cuda:0 bf16[1, 512, 4096]" - t429 = prims.convert_element_type(t427, dtypes.float32) # t429: "cuda:0 f32[1, 512, 4096]" - t430 = prims.convert_element_type(t428, dtypes.float32) # t430: "cuda:0 f32[1, 512, 4096]" - t431 = ltorch.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - # t431 = prims.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]" - t432 = prims.convert_element_type(t431, dtypes.bfloat16) # t432: "cuda:0 bf16[1, 512, 4096]" - t433 = torch.nn.functional.linear(t432, t_transformer_h_3_attn_attn_weight, None) # t433: "cuda:0 bf16[1, 512, 12288]" - # t433 = ltorch.linear(t432, t_transformer_h_3_attn_attn_weight, None) # t433: "cuda:0 bf16[1, 512, 12288]" - # t433 = prims.linear(t432, t_transformer_h_3_attn_attn_weight, None) # t433: "cuda:0 bf16[1, 512, 12288]" - t439 = prims.reshape(t433, (1, 512, 32, 3, 128)) # t439: "cuda:0 bf16[1, 512, 32, 3, 128]" - t445 = prims.transpose(t439, (0, 2, 3, 1, 4)) # t445: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t446, t447, t448) = ltorch.split(t445, (1, 1, 1), 2) - # t446 = prims.slice_prim(t445, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t446: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t447 = prims.slice_prim(t445, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t447: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t448 = prims.slice_prim(t445, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t448: "cuda:0 bf16[1, 32, 1, 512, 128]" - t454 = prims.reshape(t446, (1, 32, 512, 128)) # t454: "cuda:0 bf16[1, 32, 512, 128]" - t460 = prims.reshape(t447, (1, 32, 512, 128)) # t460: "cuda:0 bf16[1, 32, 512, 128]" - t466 = prims.reshape(t448, (1, 32, 512, 128)) # t466: "cuda:0 bf16[1, 32, 512, 128]" - t467 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t467: "cuda:0 bf16[1, 32, 512, 128]" - t468 = prims.slice_prim(t467, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t468: "cuda:0 bf16[1, 32, 512, 64]" - t469 = prims.slice_prim(t467, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t469: "cuda:0 bf16[1, 32, 512, 64]" - t470 = prims.convert_element_type(t469, dtypes.float32) # t470: "cuda:0 f32[1, 32, 512, 64]" - t471 = prims.neg(t470) # t471: "cuda:0 f32[1, 32, 512, 64]" - t472 = prims.convert_element_type(t471, dtypes.bfloat16) # t472: "cuda:0 bf16[1, 32, 512, 64]" - t474 = prims.cat((t472, t468), -1) # t474: "cuda:0 bf16[1, 32, 512, 128]" - t475 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t475: "cuda:0 f32[1, 32, 512, 128]" - t476 = prims.convert_element_type(t467, dtypes.float32) # t476: "cuda:0 f32[1, 32, 512, 128]" - t477 = ltorch.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - # t477 = prims.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]" - t478 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t478: "cuda:0 f32[1, 32, 512, 128]" - t479 = prims.convert_element_type(t474, dtypes.float32) # t479: "cuda:0 f32[1, 32, 512, 128]" - t480 = ltorch.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - # t480 = prims.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]" - t481 = ltorch.add(t477, t480, alpha=None) # t481: "cuda:0 f32[1, 32, 512, 128]" - # t481 = prims.add(t477, t480) # t481: "cuda:0 f32[1, 32, 512, 128]" - t482 = prims.convert_element_type(t481, dtypes.bfloat16) # t482: "cuda:0 bf16[1, 32, 512, 128]" - t483 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t483: "cuda:0 bf16[1, 32, 512, 128]" - t484 = prims.slice_prim(t483, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t484: "cuda:0 bf16[1, 32, 512, 64]" - t485 = prims.slice_prim(t483, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t485: "cuda:0 bf16[1, 32, 512, 64]" - t486 = prims.convert_element_type(t485, dtypes.float32) # t486: "cuda:0 f32[1, 32, 512, 64]" - t487 = prims.neg(t486) # t487: "cuda:0 f32[1, 32, 512, 64]" - t488 = prims.convert_element_type(t487, dtypes.bfloat16) # t488: "cuda:0 bf16[1, 32, 512, 64]" - t490 = prims.cat((t488, t484), -1) # t490: "cuda:0 bf16[1, 32, 512, 128]" - t491 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t491: "cuda:0 f32[1, 32, 512, 128]" - t492 = prims.convert_element_type(t483, dtypes.float32) # t492: "cuda:0 f32[1, 32, 512, 128]" - t493 = ltorch.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - # t493 = prims.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]" - t494 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t494: "cuda:0 f32[1, 32, 512, 128]" - t495 = prims.convert_element_type(t490, dtypes.float32) # t495: "cuda:0 f32[1, 32, 512, 128]" - t496 = ltorch.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - # t496 = prims.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]" - t497 = ltorch.add(t493, t496, alpha=None) # t497: "cuda:0 f32[1, 32, 512, 128]" - # t497 = prims.add(t493, t496) # t497: "cuda:0 f32[1, 32, 512, 128]" - t498 = prims.convert_element_type(t497, dtypes.bfloat16) # t498: "cuda:0 bf16[1, 32, 512, 128]" - t499 = prims.slice_prim(t454, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t499: "cuda:0 bf16[1, 32, 512, 0]" - t501 = prims.cat((t482, t499), -1) # t501: "cuda:0 bf16[1, 32, 512, 128]" - t502 = prims.slice_prim(t460, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t502: "cuda:0 bf16[1, 32, 512, 0]" - t504 = prims.cat((t498, t502), -1) # t504: "cuda:0 bf16[1, 32, 512, 128]" - (t505, t506, t507, t508) = cudnn_sdpa_fwd(t501, t504, t466, None, 0.0, True, scale=0.08838834764831843) - t511 = prims.transpose(t505, (0, 2, 1, 3)) # t511: "cuda:0 bf16[1, 512, 32, 128]" - t515 = prims.reshape(t511, (1, 512, 4096)) # t515: "cuda:0 bf16[1, 512, 4096]" - t516 = torch.nn.functional.linear(t515, t_transformer_h_3_attn_proj_weight, None) # t516: "cuda:0 bf16[1, 512, 4096]" - # t516 = ltorch.linear(t515, t_transformer_h_3_attn_proj_weight, None) # t516: "cuda:0 bf16[1, 512, 4096]" - # t516 = prims.linear(t515, t_transformer_h_3_attn_proj_weight, None) # t516: "cuda:0 bf16[1, 512, 4096]" - t517 = prims.convert_element_type(t516, dtypes.float32) # t517: "cuda:0 f32[1, 512, 4096]" - t518 = prims.convert_element_type(t414, dtypes.float32) # t518: "cuda:0 f32[1, 512, 4096]" - t519 = ltorch.add(t517, t518, alpha=None) # t519: "cuda:0 f32[1, 512, 4096]" - # t519 = prims.add(t517, t518) # t519: "cuda:0 f32[1, 512, 4096]" - t520 = prims.convert_element_type(t519, dtypes.bfloat16) # t520: "cuda:0 bf16[1, 512, 4096]" - t521 = prims.convert_element_type(t520, dtypes.float32) # t521: "cuda:0 f32[1, 512, 4096]" - t522 = ltorch.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - # t522 = prims.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]" - t524 = prims.sum(t522, (2,)) # t524: "cuda:0 f32[1, 512]" - t525 = prims.broadcast_in_dim(t524, [1, 512, 1], [0, 1]) # t525: "cuda:0 f32[1, 512, 1]" - t527 = ltorch.true_divide(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - # t527 = prims.div(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]" - t529 = ltorch.add(t527, 1e-05, alpha=None) # t529: "cuda:0 f32[1, 512, 1]" - # t529 = prims.add(t527, 1e-05) # t529: "cuda:0 f32[1, 512, 1]" - t530 = prims.rsqrt(t529) # t530: "cuda:0 f32[1, 512, 1]" - t531 = prims.broadcast_in_dim(t530, (1, 512, 4096), (0, 1, 2)) # t531: "cuda:0 f32[1, 512, 4096]" - t532 = ltorch.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - # t532 = prims.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]" - t533 = prims.convert_element_type(t532, dtypes.bfloat16) # t533: "cuda:0 bf16[1, 512, 4096]" - t534 = prims.broadcast_in_dim(t_transformer_h_3_norm_2_weight, (1, 512, 4096), (2,)) # t534: "cuda:0 bf16[1, 512, 4096]" - t535 = prims.convert_element_type(t533, dtypes.float32) # t535: "cuda:0 f32[1, 512, 4096]" - t536 = prims.convert_element_type(t534, dtypes.float32) # t536: "cuda:0 f32[1, 512, 4096]" - t537 = ltorch.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - # t537 = prims.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]" - t538 = prims.convert_element_type(t537, dtypes.bfloat16) # t538: "cuda:0 bf16[1, 512, 4096]" - t539 = torch.nn.functional.linear(t538, t_transformer_h_3_mlp_fc_1_weight, None) # t539: "cuda:0 bf16[1, 512, 11008]" - # t539 = ltorch.linear(t538, t_transformer_h_3_mlp_fc_1_weight, None) # t539: "cuda:0 bf16[1, 512, 11008]" - # t539 = prims.linear(t538, t_transformer_h_3_mlp_fc_1_weight, None) # t539: "cuda:0 bf16[1, 512, 11008]" - t540 = torch.nn.functional.linear(t538, t_transformer_h_3_mlp_fc_2_weight, None) # t540: "cuda:0 bf16[1, 512, 11008]" - # t540 = ltorch.linear(t538, t_transformer_h_3_mlp_fc_2_weight, None) # t540: "cuda:0 bf16[1, 512, 11008]" - # t540 = prims.linear(t538, t_transformer_h_3_mlp_fc_2_weight, None) # t540: "cuda:0 bf16[1, 512, 11008]" - t541 = prims.convert_element_type(t539, dtypes.float32) # t541: "cuda:0 f32[1, 512, 11008]" - t542 = prims.neg(t541) # t542: "cuda:0 f32[1, 512, 11008]" - t543 = prims.exp(t542) # t543: "cuda:0 f32[1, 512, 11008]" - t544 = ltorch.add(1.0, t543, alpha=None) # t544: "cuda:0 f32[1, 512, 11008]" - # t544 = prims.add(1.0, t543) # t544: "cuda:0 f32[1, 512, 11008]" - t545 = prims.reciprocal(t544) # t545: "cuda:0 f32[1, 512, 11008]" - t546 = prims.convert_element_type(t545, dtypes.bfloat16) # t546: "cuda:0 bf16[1, 512, 11008]" - t547 = prims.convert_element_type(t539, dtypes.float32) # t547: "cuda:0 f32[1, 512, 11008]" - t548 = prims.convert_element_type(t546, dtypes.float32) # t548: "cuda:0 f32[1, 512, 11008]" - t549 = ltorch.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - # t549 = prims.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]" - t550 = prims.convert_element_type(t549, dtypes.bfloat16) # t550: "cuda:0 bf16[1, 512, 11008]" - t551 = prims.convert_element_type(t550, dtypes.float32) # t551: "cuda:0 f32[1, 512, 11008]" - t552 = prims.convert_element_type(t540, dtypes.float32) # t552: "cuda:0 f32[1, 512, 11008]" - t553 = ltorch.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - # t553 = prims.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]" - t554 = prims.convert_element_type(t553, dtypes.bfloat16) # t554: "cuda:0 bf16[1, 512, 11008]" - t555 = torch.nn.functional.linear(t554, t_transformer_h_3_mlp_proj_weight, None) # t555: "cuda:0 bf16[1, 512, 4096]" - # t555 = ltorch.linear(t554, t_transformer_h_3_mlp_proj_weight, None) # t555: "cuda:0 bf16[1, 512, 4096]" - # t555 = prims.linear(t554, t_transformer_h_3_mlp_proj_weight, None) # t555: "cuda:0 bf16[1, 512, 4096]" - t556 = prims.convert_element_type(t555, dtypes.float32) # t556: "cuda:0 f32[1, 512, 4096]" - t557 = prims.convert_element_type(t520, dtypes.float32) # t557: "cuda:0 f32[1, 512, 4096]" - t558 = ltorch.add(t556, t557, alpha=None) # t558: "cuda:0 f32[1, 512, 4096]" - # t558 = prims.add(t556, t557) # t558: "cuda:0 f32[1, 512, 4096]" - t559 = prims.convert_element_type(t558, dtypes.bfloat16) # t559: "cuda:0 bf16[1, 512, 4096]" - t560 = prims.convert_element_type(t559, dtypes.float32) # t560: "cuda:0 f32[1, 512, 4096]" - t561 = ltorch.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - # t561 = prims.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]" - t563 = prims.sum(t561, (2,)) # t563: "cuda:0 f32[1, 512]" - t564 = prims.broadcast_in_dim(t563, [1, 512, 1], [0, 1]) # t564: "cuda:0 f32[1, 512, 1]" - t566 = ltorch.true_divide(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - # t566 = prims.div(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]" - t568 = ltorch.add(t566, 1e-05, alpha=None) # t568: "cuda:0 f32[1, 512, 1]" - # t568 = prims.add(t566, 1e-05) # t568: "cuda:0 f32[1, 512, 1]" - t569 = prims.rsqrt(t568) # t569: "cuda:0 f32[1, 512, 1]" - t570 = prims.broadcast_in_dim(t569, (1, 512, 4096), (0, 1, 2)) # t570: "cuda:0 f32[1, 512, 4096]" - t571 = ltorch.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - # t571 = prims.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]" - t572 = prims.convert_element_type(t571, dtypes.bfloat16) # t572: "cuda:0 bf16[1, 512, 4096]" - t573 = prims.broadcast_in_dim(t_transformer_h_4_norm_1_weight, (1, 512, 4096), (2,)) # t573: "cuda:0 bf16[1, 512, 4096]" - t574 = prims.convert_element_type(t572, dtypes.float32) # t574: "cuda:0 f32[1, 512, 4096]" - t575 = prims.convert_element_type(t573, dtypes.float32) # t575: "cuda:0 f32[1, 512, 4096]" - t576 = ltorch.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - # t576 = prims.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]" - t577 = prims.convert_element_type(t576, dtypes.bfloat16) # t577: "cuda:0 bf16[1, 512, 4096]" - t578 = torch.nn.functional.linear(t577, t_transformer_h_4_attn_attn_weight, None) # t578: "cuda:0 bf16[1, 512, 12288]" - # t578 = ltorch.linear(t577, t_transformer_h_4_attn_attn_weight, None) # t578: "cuda:0 bf16[1, 512, 12288]" - # t578 = prims.linear(t577, t_transformer_h_4_attn_attn_weight, None) # t578: "cuda:0 bf16[1, 512, 12288]" - t584 = prims.reshape(t578, (1, 512, 32, 3, 128)) # t584: "cuda:0 bf16[1, 512, 32, 3, 128]" - t590 = prims.transpose(t584, (0, 2, 3, 1, 4)) # t590: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t591, t592, t593) = ltorch.split(t590, (1, 1, 1), 2) - # t591 = prims.slice_prim(t590, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t591: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t592 = prims.slice_prim(t590, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t592: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t593 = prims.slice_prim(t590, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t593: "cuda:0 bf16[1, 32, 1, 512, 128]" - t599 = prims.reshape(t591, (1, 32, 512, 128)) # t599: "cuda:0 bf16[1, 32, 512, 128]" - t605 = prims.reshape(t592, (1, 32, 512, 128)) # t605: "cuda:0 bf16[1, 32, 512, 128]" - t611 = prims.reshape(t593, (1, 32, 512, 128)) # t611: "cuda:0 bf16[1, 32, 512, 128]" - t612 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t612: "cuda:0 bf16[1, 32, 512, 128]" - t613 = prims.slice_prim(t612, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t613: "cuda:0 bf16[1, 32, 512, 64]" - t614 = prims.slice_prim(t612, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t614: "cuda:0 bf16[1, 32, 512, 64]" - t615 = prims.convert_element_type(t614, dtypes.float32) # t615: "cuda:0 f32[1, 32, 512, 64]" - t616 = prims.neg(t615) # t616: "cuda:0 f32[1, 32, 512, 64]" - t617 = prims.convert_element_type(t616, dtypes.bfloat16) # t617: "cuda:0 bf16[1, 32, 512, 64]" - t619 = prims.cat((t617, t613), -1) # t619: "cuda:0 bf16[1, 32, 512, 128]" - t620 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t620: "cuda:0 f32[1, 32, 512, 128]" - t621 = prims.convert_element_type(t612, dtypes.float32) # t621: "cuda:0 f32[1, 32, 512, 128]" - t622 = ltorch.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - # t622 = prims.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]" - t623 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t623: "cuda:0 f32[1, 32, 512, 128]" - t624 = prims.convert_element_type(t619, dtypes.float32) # t624: "cuda:0 f32[1, 32, 512, 128]" - t625 = ltorch.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - # t625 = prims.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]" - t626 = ltorch.add(t622, t625, alpha=None) # t626: "cuda:0 f32[1, 32, 512, 128]" - # t626 = prims.add(t622, t625) # t626: "cuda:0 f32[1, 32, 512, 128]" - t627 = prims.convert_element_type(t626, dtypes.bfloat16) # t627: "cuda:0 bf16[1, 32, 512, 128]" - t628 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t628: "cuda:0 bf16[1, 32, 512, 128]" - t629 = prims.slice_prim(t628, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t629: "cuda:0 bf16[1, 32, 512, 64]" - t630 = prims.slice_prim(t628, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t630: "cuda:0 bf16[1, 32, 512, 64]" - t631 = prims.convert_element_type(t630, dtypes.float32) # t631: "cuda:0 f32[1, 32, 512, 64]" - t632 = prims.neg(t631) # t632: "cuda:0 f32[1, 32, 512, 64]" - t633 = prims.convert_element_type(t632, dtypes.bfloat16) # t633: "cuda:0 bf16[1, 32, 512, 64]" - t635 = prims.cat((t633, t629), -1) # t635: "cuda:0 bf16[1, 32, 512, 128]" - t636 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t636: "cuda:0 f32[1, 32, 512, 128]" - t637 = prims.convert_element_type(t628, dtypes.float32) # t637: "cuda:0 f32[1, 32, 512, 128]" - t638 = ltorch.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - # t638 = prims.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]" - t639 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t639: "cuda:0 f32[1, 32, 512, 128]" - t640 = prims.convert_element_type(t635, dtypes.float32) # t640: "cuda:0 f32[1, 32, 512, 128]" - t641 = ltorch.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - # t641 = prims.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]" - t642 = ltorch.add(t638, t641, alpha=None) # t642: "cuda:0 f32[1, 32, 512, 128]" - # t642 = prims.add(t638, t641) # t642: "cuda:0 f32[1, 32, 512, 128]" - t643 = prims.convert_element_type(t642, dtypes.bfloat16) # t643: "cuda:0 bf16[1, 32, 512, 128]" - t644 = prims.slice_prim(t599, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t644: "cuda:0 bf16[1, 32, 512, 0]" - t646 = prims.cat((t627, t644), -1) # t646: "cuda:0 bf16[1, 32, 512, 128]" - t647 = prims.slice_prim(t605, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t647: "cuda:0 bf16[1, 32, 512, 0]" - t649 = prims.cat((t643, t647), -1) # t649: "cuda:0 bf16[1, 32, 512, 128]" - (t650, t651, t652, t653) = cudnn_sdpa_fwd(t646, t649, t611, None, 0.0, True, scale=0.08838834764831843) - t656 = prims.transpose(t650, (0, 2, 1, 3)) # t656: "cuda:0 bf16[1, 512, 32, 128]" - t660 = prims.reshape(t656, (1, 512, 4096)) # t660: "cuda:0 bf16[1, 512, 4096]" - t661 = torch.nn.functional.linear(t660, t_transformer_h_4_attn_proj_weight, None) # t661: "cuda:0 bf16[1, 512, 4096]" - # t661 = ltorch.linear(t660, t_transformer_h_4_attn_proj_weight, None) # t661: "cuda:0 bf16[1, 512, 4096]" - # t661 = prims.linear(t660, t_transformer_h_4_attn_proj_weight, None) # t661: "cuda:0 bf16[1, 512, 4096]" - t662 = prims.convert_element_type(t661, dtypes.float32) # t662: "cuda:0 f32[1, 512, 4096]" - t663 = prims.convert_element_type(t559, dtypes.float32) # t663: "cuda:0 f32[1, 512, 4096]" - t664 = ltorch.add(t662, t663, alpha=None) # t664: "cuda:0 f32[1, 512, 4096]" - # t664 = prims.add(t662, t663) # t664: "cuda:0 f32[1, 512, 4096]" - t665 = prims.convert_element_type(t664, dtypes.bfloat16) # t665: "cuda:0 bf16[1, 512, 4096]" - t666 = prims.convert_element_type(t665, dtypes.float32) # t666: "cuda:0 f32[1, 512, 4096]" - t667 = ltorch.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - # t667 = prims.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]" - t669 = prims.sum(t667, (2,)) # t669: "cuda:0 f32[1, 512]" - t670 = prims.broadcast_in_dim(t669, [1, 512, 1], [0, 1]) # t670: "cuda:0 f32[1, 512, 1]" - t672 = ltorch.true_divide(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - # t672 = prims.div(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]" - t674 = ltorch.add(t672, 1e-05, alpha=None) # t674: "cuda:0 f32[1, 512, 1]" - # t674 = prims.add(t672, 1e-05) # t674: "cuda:0 f32[1, 512, 1]" - t675 = prims.rsqrt(t674) # t675: "cuda:0 f32[1, 512, 1]" - t676 = prims.broadcast_in_dim(t675, (1, 512, 4096), (0, 1, 2)) # t676: "cuda:0 f32[1, 512, 4096]" - t677 = ltorch.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - # t677 = prims.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]" - t678 = prims.convert_element_type(t677, dtypes.bfloat16) # t678: "cuda:0 bf16[1, 512, 4096]" - t679 = prims.broadcast_in_dim(t_transformer_h_4_norm_2_weight, (1, 512, 4096), (2,)) # t679: "cuda:0 bf16[1, 512, 4096]" - t680 = prims.convert_element_type(t678, dtypes.float32) # t680: "cuda:0 f32[1, 512, 4096]" - t681 = prims.convert_element_type(t679, dtypes.float32) # t681: "cuda:0 f32[1, 512, 4096]" - t682 = ltorch.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - # t682 = prims.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]" - t683 = prims.convert_element_type(t682, dtypes.bfloat16) # t683: "cuda:0 bf16[1, 512, 4096]" - t684 = torch.nn.functional.linear(t683, t_transformer_h_4_mlp_fc_1_weight, None) # t684: "cuda:0 bf16[1, 512, 11008]" - # t684 = ltorch.linear(t683, t_transformer_h_4_mlp_fc_1_weight, None) # t684: "cuda:0 bf16[1, 512, 11008]" - # t684 = prims.linear(t683, t_transformer_h_4_mlp_fc_1_weight, None) # t684: "cuda:0 bf16[1, 512, 11008]" - t685 = torch.nn.functional.linear(t683, t_transformer_h_4_mlp_fc_2_weight, None) # t685: "cuda:0 bf16[1, 512, 11008]" - # t685 = ltorch.linear(t683, t_transformer_h_4_mlp_fc_2_weight, None) # t685: "cuda:0 bf16[1, 512, 11008]" - # t685 = prims.linear(t683, t_transformer_h_4_mlp_fc_2_weight, None) # t685: "cuda:0 bf16[1, 512, 11008]" - t686 = prims.convert_element_type(t684, dtypes.float32) # t686: "cuda:0 f32[1, 512, 11008]" - t687 = prims.neg(t686) # t687: "cuda:0 f32[1, 512, 11008]" - t688 = prims.exp(t687) # t688: "cuda:0 f32[1, 512, 11008]" - t689 = ltorch.add(1.0, t688, alpha=None) # t689: "cuda:0 f32[1, 512, 11008]" - # t689 = prims.add(1.0, t688) # t689: "cuda:0 f32[1, 512, 11008]" - t690 = prims.reciprocal(t689) # t690: "cuda:0 f32[1, 512, 11008]" - t691 = prims.convert_element_type(t690, dtypes.bfloat16) # t691: "cuda:0 bf16[1, 512, 11008]" - t692 = prims.convert_element_type(t684, dtypes.float32) # t692: "cuda:0 f32[1, 512, 11008]" - t693 = prims.convert_element_type(t691, dtypes.float32) # t693: "cuda:0 f32[1, 512, 11008]" - t694 = ltorch.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - # t694 = prims.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]" - t695 = prims.convert_element_type(t694, dtypes.bfloat16) # t695: "cuda:0 bf16[1, 512, 11008]" - t696 = prims.convert_element_type(t695, dtypes.float32) # t696: "cuda:0 f32[1, 512, 11008]" - t697 = prims.convert_element_type(t685, dtypes.float32) # t697: "cuda:0 f32[1, 512, 11008]" - t698 = ltorch.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - # t698 = prims.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]" - t699 = prims.convert_element_type(t698, dtypes.bfloat16) # t699: "cuda:0 bf16[1, 512, 11008]" - t700 = torch.nn.functional.linear(t699, t_transformer_h_4_mlp_proj_weight, None) # t700: "cuda:0 bf16[1, 512, 4096]" - # t700 = ltorch.linear(t699, t_transformer_h_4_mlp_proj_weight, None) # t700: "cuda:0 bf16[1, 512, 4096]" - # t700 = prims.linear(t699, t_transformer_h_4_mlp_proj_weight, None) # t700: "cuda:0 bf16[1, 512, 4096]" - t701 = prims.convert_element_type(t700, dtypes.float32) # t701: "cuda:0 f32[1, 512, 4096]" - t702 = prims.convert_element_type(t665, dtypes.float32) # t702: "cuda:0 f32[1, 512, 4096]" - t703 = ltorch.add(t701, t702, alpha=None) # t703: "cuda:0 f32[1, 512, 4096]" - # t703 = prims.add(t701, t702) # t703: "cuda:0 f32[1, 512, 4096]" - t704 = prims.convert_element_type(t703, dtypes.bfloat16) # t704: "cuda:0 bf16[1, 512, 4096]" - t705 = prims.convert_element_type(t704, dtypes.float32) # t705: "cuda:0 f32[1, 512, 4096]" - t706 = ltorch.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - # t706 = prims.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]" - t708 = prims.sum(t706, (2,)) # t708: "cuda:0 f32[1, 512]" - t709 = prims.broadcast_in_dim(t708, [1, 512, 1], [0, 1]) # t709: "cuda:0 f32[1, 512, 1]" - t711 = ltorch.true_divide(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - # t711 = prims.div(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]" - t713 = ltorch.add(t711, 1e-05, alpha=None) # t713: "cuda:0 f32[1, 512, 1]" - # t713 = prims.add(t711, 1e-05) # t713: "cuda:0 f32[1, 512, 1]" - t714 = prims.rsqrt(t713) # t714: "cuda:0 f32[1, 512, 1]" - t715 = prims.broadcast_in_dim(t714, (1, 512, 4096), (0, 1, 2)) # t715: "cuda:0 f32[1, 512, 4096]" - t716 = ltorch.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - # t716 = prims.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]" - t717 = prims.convert_element_type(t716, dtypes.bfloat16) # t717: "cuda:0 bf16[1, 512, 4096]" - t718 = prims.broadcast_in_dim(t_transformer_h_5_norm_1_weight, (1, 512, 4096), (2,)) # t718: "cuda:0 bf16[1, 512, 4096]" - t719 = prims.convert_element_type(t717, dtypes.float32) # t719: "cuda:0 f32[1, 512, 4096]" - t720 = prims.convert_element_type(t718, dtypes.float32) # t720: "cuda:0 f32[1, 512, 4096]" - t721 = ltorch.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - # t721 = prims.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]" - t722 = prims.convert_element_type(t721, dtypes.bfloat16) # t722: "cuda:0 bf16[1, 512, 4096]" - t723 = torch.nn.functional.linear(t722, t_transformer_h_5_attn_attn_weight, None) # t723: "cuda:0 bf16[1, 512, 12288]" - # t723 = ltorch.linear(t722, t_transformer_h_5_attn_attn_weight, None) # t723: "cuda:0 bf16[1, 512, 12288]" - # t723 = prims.linear(t722, t_transformer_h_5_attn_attn_weight, None) # t723: "cuda:0 bf16[1, 512, 12288]" - t729 = prims.reshape(t723, (1, 512, 32, 3, 128)) # t729: "cuda:0 bf16[1, 512, 32, 3, 128]" - t735 = prims.transpose(t729, (0, 2, 3, 1, 4)) # t735: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t736, t737, t738) = ltorch.split(t735, (1, 1, 1), 2) - # t736 = prims.slice_prim(t735, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t736: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t737 = prims.slice_prim(t735, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t737: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t738 = prims.slice_prim(t735, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t738: "cuda:0 bf16[1, 32, 1, 512, 128]" - t744 = prims.reshape(t736, (1, 32, 512, 128)) # t744: "cuda:0 bf16[1, 32, 512, 128]" - t750 = prims.reshape(t737, (1, 32, 512, 128)) # t750: "cuda:0 bf16[1, 32, 512, 128]" - t756 = prims.reshape(t738, (1, 32, 512, 128)) # t756: "cuda:0 bf16[1, 32, 512, 128]" - t757 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t757: "cuda:0 bf16[1, 32, 512, 128]" - t758 = prims.slice_prim(t757, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t758: "cuda:0 bf16[1, 32, 512, 64]" - t759 = prims.slice_prim(t757, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t759: "cuda:0 bf16[1, 32, 512, 64]" - t760 = prims.convert_element_type(t759, dtypes.float32) # t760: "cuda:0 f32[1, 32, 512, 64]" - t761 = prims.neg(t760) # t761: "cuda:0 f32[1, 32, 512, 64]" - t762 = prims.convert_element_type(t761, dtypes.bfloat16) # t762: "cuda:0 bf16[1, 32, 512, 64]" - t764 = prims.cat((t762, t758), -1) # t764: "cuda:0 bf16[1, 32, 512, 128]" - t765 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t765: "cuda:0 f32[1, 32, 512, 128]" - t766 = prims.convert_element_type(t757, dtypes.float32) # t766: "cuda:0 f32[1, 32, 512, 128]" - t767 = ltorch.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - # t767 = prims.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]" - t768 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t768: "cuda:0 f32[1, 32, 512, 128]" - t769 = prims.convert_element_type(t764, dtypes.float32) # t769: "cuda:0 f32[1, 32, 512, 128]" - t770 = ltorch.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - # t770 = prims.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]" - t771 = ltorch.add(t767, t770, alpha=None) # t771: "cuda:0 f32[1, 32, 512, 128]" - # t771 = prims.add(t767, t770) # t771: "cuda:0 f32[1, 32, 512, 128]" - t772 = prims.convert_element_type(t771, dtypes.bfloat16) # t772: "cuda:0 bf16[1, 32, 512, 128]" - t773 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t773: "cuda:0 bf16[1, 32, 512, 128]" - t774 = prims.slice_prim(t773, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t774: "cuda:0 bf16[1, 32, 512, 64]" - t775 = prims.slice_prim(t773, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t775: "cuda:0 bf16[1, 32, 512, 64]" - t776 = prims.convert_element_type(t775, dtypes.float32) # t776: "cuda:0 f32[1, 32, 512, 64]" - t777 = prims.neg(t776) # t777: "cuda:0 f32[1, 32, 512, 64]" - t778 = prims.convert_element_type(t777, dtypes.bfloat16) # t778: "cuda:0 bf16[1, 32, 512, 64]" - t780 = prims.cat((t778, t774), -1) # t780: "cuda:0 bf16[1, 32, 512, 128]" - t781 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t781: "cuda:0 f32[1, 32, 512, 128]" - t782 = prims.convert_element_type(t773, dtypes.float32) # t782: "cuda:0 f32[1, 32, 512, 128]" - t783 = ltorch.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - # t783 = prims.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]" - t784 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t784: "cuda:0 f32[1, 32, 512, 128]" - t785 = prims.convert_element_type(t780, dtypes.float32) # t785: "cuda:0 f32[1, 32, 512, 128]" - t786 = ltorch.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - # t786 = prims.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]" - t787 = ltorch.add(t783, t786, alpha=None) # t787: "cuda:0 f32[1, 32, 512, 128]" - # t787 = prims.add(t783, t786) # t787: "cuda:0 f32[1, 32, 512, 128]" - t788 = prims.convert_element_type(t787, dtypes.bfloat16) # t788: "cuda:0 bf16[1, 32, 512, 128]" - t789 = prims.slice_prim(t744, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t789: "cuda:0 bf16[1, 32, 512, 0]" - t791 = prims.cat((t772, t789), -1) # t791: "cuda:0 bf16[1, 32, 512, 128]" - t792 = prims.slice_prim(t750, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t792: "cuda:0 bf16[1, 32, 512, 0]" - t794 = prims.cat((t788, t792), -1) # t794: "cuda:0 bf16[1, 32, 512, 128]" - (t795, t796, t797, t798) = cudnn_sdpa_fwd(t791, t794, t756, None, 0.0, True, scale=0.08838834764831843) - t801 = prims.transpose(t795, (0, 2, 1, 3)) # t801: "cuda:0 bf16[1, 512, 32, 128]" - t805 = prims.reshape(t801, (1, 512, 4096)) # t805: "cuda:0 bf16[1, 512, 4096]" - t806 = torch.nn.functional.linear(t805, t_transformer_h_5_attn_proj_weight, None) # t806: "cuda:0 bf16[1, 512, 4096]" - # t806 = ltorch.linear(t805, t_transformer_h_5_attn_proj_weight, None) # t806: "cuda:0 bf16[1, 512, 4096]" - # t806 = prims.linear(t805, t_transformer_h_5_attn_proj_weight, None) # t806: "cuda:0 bf16[1, 512, 4096]" - t807 = prims.convert_element_type(t806, dtypes.float32) # t807: "cuda:0 f32[1, 512, 4096]" - t808 = prims.convert_element_type(t704, dtypes.float32) # t808: "cuda:0 f32[1, 512, 4096]" - t809 = ltorch.add(t807, t808, alpha=None) # t809: "cuda:0 f32[1, 512, 4096]" - # t809 = prims.add(t807, t808) # t809: "cuda:0 f32[1, 512, 4096]" - t810 = prims.convert_element_type(t809, dtypes.bfloat16) # t810: "cuda:0 bf16[1, 512, 4096]" - t811 = prims.convert_element_type(t810, dtypes.float32) # t811: "cuda:0 f32[1, 512, 4096]" - t812 = ltorch.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - # t812 = prims.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]" - t814 = prims.sum(t812, (2,)) # t814: "cuda:0 f32[1, 512]" - t815 = prims.broadcast_in_dim(t814, [1, 512, 1], [0, 1]) # t815: "cuda:0 f32[1, 512, 1]" - t817 = ltorch.true_divide(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - # t817 = prims.div(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]" - t819 = ltorch.add(t817, 1e-05, alpha=None) # t819: "cuda:0 f32[1, 512, 1]" - # t819 = prims.add(t817, 1e-05) # t819: "cuda:0 f32[1, 512, 1]" - t820 = prims.rsqrt(t819) # t820: "cuda:0 f32[1, 512, 1]" - t821 = prims.broadcast_in_dim(t820, (1, 512, 4096), (0, 1, 2)) # t821: "cuda:0 f32[1, 512, 4096]" - t822 = ltorch.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - # t822 = prims.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]" - t823 = prims.convert_element_type(t822, dtypes.bfloat16) # t823: "cuda:0 bf16[1, 512, 4096]" - t824 = prims.broadcast_in_dim(t_transformer_h_5_norm_2_weight, (1, 512, 4096), (2,)) # t824: "cuda:0 bf16[1, 512, 4096]" - t825 = prims.convert_element_type(t823, dtypes.float32) # t825: "cuda:0 f32[1, 512, 4096]" - t826 = prims.convert_element_type(t824, dtypes.float32) # t826: "cuda:0 f32[1, 512, 4096]" - t827 = ltorch.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - # t827 = prims.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]" - t828 = prims.convert_element_type(t827, dtypes.bfloat16) # t828: "cuda:0 bf16[1, 512, 4096]" - t829 = torch.nn.functional.linear(t828, t_transformer_h_5_mlp_fc_1_weight, None) # t829: "cuda:0 bf16[1, 512, 11008]" - # t829 = ltorch.linear(t828, t_transformer_h_5_mlp_fc_1_weight, None) # t829: "cuda:0 bf16[1, 512, 11008]" - # t829 = prims.linear(t828, t_transformer_h_5_mlp_fc_1_weight, None) # t829: "cuda:0 bf16[1, 512, 11008]" - t830 = torch.nn.functional.linear(t828, t_transformer_h_5_mlp_fc_2_weight, None) # t830: "cuda:0 bf16[1, 512, 11008]" - # t830 = ltorch.linear(t828, t_transformer_h_5_mlp_fc_2_weight, None) # t830: "cuda:0 bf16[1, 512, 11008]" - # t830 = prims.linear(t828, t_transformer_h_5_mlp_fc_2_weight, None) # t830: "cuda:0 bf16[1, 512, 11008]" - t831 = prims.convert_element_type(t829, dtypes.float32) # t831: "cuda:0 f32[1, 512, 11008]" - t832 = prims.neg(t831) # t832: "cuda:0 f32[1, 512, 11008]" - t833 = prims.exp(t832) # t833: "cuda:0 f32[1, 512, 11008]" - t834 = ltorch.add(1.0, t833, alpha=None) # t834: "cuda:0 f32[1, 512, 11008]" - # t834 = prims.add(1.0, t833) # t834: "cuda:0 f32[1, 512, 11008]" - t835 = prims.reciprocal(t834) # t835: "cuda:0 f32[1, 512, 11008]" - t836 = prims.convert_element_type(t835, dtypes.bfloat16) # t836: "cuda:0 bf16[1, 512, 11008]" - t837 = prims.convert_element_type(t829, dtypes.float32) # t837: "cuda:0 f32[1, 512, 11008]" - t838 = prims.convert_element_type(t836, dtypes.float32) # t838: "cuda:0 f32[1, 512, 11008]" - t839 = ltorch.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - # t839 = prims.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]" - t840 = prims.convert_element_type(t839, dtypes.bfloat16) # t840: "cuda:0 bf16[1, 512, 11008]" - t841 = prims.convert_element_type(t840, dtypes.float32) # t841: "cuda:0 f32[1, 512, 11008]" - t842 = prims.convert_element_type(t830, dtypes.float32) # t842: "cuda:0 f32[1, 512, 11008]" - t843 = ltorch.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - # t843 = prims.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]" - t844 = prims.convert_element_type(t843, dtypes.bfloat16) # t844: "cuda:0 bf16[1, 512, 11008]" - t845 = torch.nn.functional.linear(t844, t_transformer_h_5_mlp_proj_weight, None) # t845: "cuda:0 bf16[1, 512, 4096]" - # t845 = ltorch.linear(t844, t_transformer_h_5_mlp_proj_weight, None) # t845: "cuda:0 bf16[1, 512, 4096]" - # t845 = prims.linear(t844, t_transformer_h_5_mlp_proj_weight, None) # t845: "cuda:0 bf16[1, 512, 4096]" - t846 = prims.convert_element_type(t845, dtypes.float32) # t846: "cuda:0 f32[1, 512, 4096]" - t847 = prims.convert_element_type(t810, dtypes.float32) # t847: "cuda:0 f32[1, 512, 4096]" - t848 = ltorch.add(t846, t847, alpha=None) # t848: "cuda:0 f32[1, 512, 4096]" - # t848 = prims.add(t846, t847) # t848: "cuda:0 f32[1, 512, 4096]" - t849 = prims.convert_element_type(t848, dtypes.bfloat16) # t849: "cuda:0 bf16[1, 512, 4096]" - t850 = prims.convert_element_type(t849, dtypes.float32) # t850: "cuda:0 f32[1, 512, 4096]" - t851 = ltorch.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - # t851 = prims.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]" - t853 = prims.sum(t851, (2,)) # t853: "cuda:0 f32[1, 512]" - t854 = prims.broadcast_in_dim(t853, [1, 512, 1], [0, 1]) # t854: "cuda:0 f32[1, 512, 1]" - t856 = ltorch.true_divide(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - # t856 = prims.div(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]" - t858 = ltorch.add(t856, 1e-05, alpha=None) # t858: "cuda:0 f32[1, 512, 1]" - # t858 = prims.add(t856, 1e-05) # t858: "cuda:0 f32[1, 512, 1]" - t859 = prims.rsqrt(t858) # t859: "cuda:0 f32[1, 512, 1]" - t860 = prims.broadcast_in_dim(t859, (1, 512, 4096), (0, 1, 2)) # t860: "cuda:0 f32[1, 512, 4096]" - t861 = ltorch.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - # t861 = prims.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]" - t862 = prims.convert_element_type(t861, dtypes.bfloat16) # t862: "cuda:0 bf16[1, 512, 4096]" - t863 = prims.broadcast_in_dim(t_transformer_h_6_norm_1_weight, (1, 512, 4096), (2,)) # t863: "cuda:0 bf16[1, 512, 4096]" - t864 = prims.convert_element_type(t862, dtypes.float32) # t864: "cuda:0 f32[1, 512, 4096]" - t865 = prims.convert_element_type(t863, dtypes.float32) # t865: "cuda:0 f32[1, 512, 4096]" - t866 = ltorch.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - # t866 = prims.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]" - t867 = prims.convert_element_type(t866, dtypes.bfloat16) # t867: "cuda:0 bf16[1, 512, 4096]" - t868 = torch.nn.functional.linear(t867, t_transformer_h_6_attn_attn_weight, None) # t868: "cuda:0 bf16[1, 512, 12288]" - # t868 = ltorch.linear(t867, t_transformer_h_6_attn_attn_weight, None) # t868: "cuda:0 bf16[1, 512, 12288]" - # t868 = prims.linear(t867, t_transformer_h_6_attn_attn_weight, None) # t868: "cuda:0 bf16[1, 512, 12288]" - t874 = prims.reshape(t868, (1, 512, 32, 3, 128)) # t874: "cuda:0 bf16[1, 512, 32, 3, 128]" - t880 = prims.transpose(t874, (0, 2, 3, 1, 4)) # t880: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t881, t882, t883) = ltorch.split(t880, (1, 1, 1), 2) - # t881 = prims.slice_prim(t880, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t881: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t882 = prims.slice_prim(t880, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t882: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t883 = prims.slice_prim(t880, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t883: "cuda:0 bf16[1, 32, 1, 512, 128]" - t889 = prims.reshape(t881, (1, 32, 512, 128)) # t889: "cuda:0 bf16[1, 32, 512, 128]" - t895 = prims.reshape(t882, (1, 32, 512, 128)) # t895: "cuda:0 bf16[1, 32, 512, 128]" - t901 = prims.reshape(t883, (1, 32, 512, 128)) # t901: "cuda:0 bf16[1, 32, 512, 128]" - t902 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t902: "cuda:0 bf16[1, 32, 512, 128]" - t903 = prims.slice_prim(t902, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t903: "cuda:0 bf16[1, 32, 512, 64]" - t904 = prims.slice_prim(t902, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t904: "cuda:0 bf16[1, 32, 512, 64]" - t905 = prims.convert_element_type(t904, dtypes.float32) # t905: "cuda:0 f32[1, 32, 512, 64]" - t906 = prims.neg(t905) # t906: "cuda:0 f32[1, 32, 512, 64]" - t907 = prims.convert_element_type(t906, dtypes.bfloat16) # t907: "cuda:0 bf16[1, 32, 512, 64]" - t909 = prims.cat((t907, t903), -1) # t909: "cuda:0 bf16[1, 32, 512, 128]" - t910 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t910: "cuda:0 f32[1, 32, 512, 128]" - t911 = prims.convert_element_type(t902, dtypes.float32) # t911: "cuda:0 f32[1, 32, 512, 128]" - t912 = ltorch.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - # t912 = prims.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]" - t913 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t913: "cuda:0 f32[1, 32, 512, 128]" - t914 = prims.convert_element_type(t909, dtypes.float32) # t914: "cuda:0 f32[1, 32, 512, 128]" - t915 = ltorch.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - # t915 = prims.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]" - t916 = ltorch.add(t912, t915, alpha=None) # t916: "cuda:0 f32[1, 32, 512, 128]" - # t916 = prims.add(t912, t915) # t916: "cuda:0 f32[1, 32, 512, 128]" - t917 = prims.convert_element_type(t916, dtypes.bfloat16) # t917: "cuda:0 bf16[1, 32, 512, 128]" - t918 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t918: "cuda:0 bf16[1, 32, 512, 128]" - t919 = prims.slice_prim(t918, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t919: "cuda:0 bf16[1, 32, 512, 64]" - t920 = prims.slice_prim(t918, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t920: "cuda:0 bf16[1, 32, 512, 64]" - t921 = prims.convert_element_type(t920, dtypes.float32) # t921: "cuda:0 f32[1, 32, 512, 64]" - t922 = prims.neg(t921) # t922: "cuda:0 f32[1, 32, 512, 64]" - t923 = prims.convert_element_type(t922, dtypes.bfloat16) # t923: "cuda:0 bf16[1, 32, 512, 64]" - t925 = prims.cat((t923, t919), -1) # t925: "cuda:0 bf16[1, 32, 512, 128]" - t926 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t926: "cuda:0 f32[1, 32, 512, 128]" - t927 = prims.convert_element_type(t918, dtypes.float32) # t927: "cuda:0 f32[1, 32, 512, 128]" - t928 = ltorch.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - # t928 = prims.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]" - t929 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t929: "cuda:0 f32[1, 32, 512, 128]" - t930 = prims.convert_element_type(t925, dtypes.float32) # t930: "cuda:0 f32[1, 32, 512, 128]" - t931 = ltorch.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - # t931 = prims.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]" - t932 = ltorch.add(t928, t931, alpha=None) # t932: "cuda:0 f32[1, 32, 512, 128]" - # t932 = prims.add(t928, t931) # t932: "cuda:0 f32[1, 32, 512, 128]" - t933 = prims.convert_element_type(t932, dtypes.bfloat16) # t933: "cuda:0 bf16[1, 32, 512, 128]" - t934 = prims.slice_prim(t889, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t934: "cuda:0 bf16[1, 32, 512, 0]" - t936 = prims.cat((t917, t934), -1) # t936: "cuda:0 bf16[1, 32, 512, 128]" - t937 = prims.slice_prim(t895, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t937: "cuda:0 bf16[1, 32, 512, 0]" - t939 = prims.cat((t933, t937), -1) # t939: "cuda:0 bf16[1, 32, 512, 128]" - (t940, t941, t942, t943) = cudnn_sdpa_fwd(t936, t939, t901, None, 0.0, True, scale=0.08838834764831843) - t946 = prims.transpose(t940, (0, 2, 1, 3)) # t946: "cuda:0 bf16[1, 512, 32, 128]" - t950 = prims.reshape(t946, (1, 512, 4096)) # t950: "cuda:0 bf16[1, 512, 4096]" - t951 = torch.nn.functional.linear(t950, t_transformer_h_6_attn_proj_weight, None) # t951: "cuda:0 bf16[1, 512, 4096]" - # t951 = ltorch.linear(t950, t_transformer_h_6_attn_proj_weight, None) # t951: "cuda:0 bf16[1, 512, 4096]" - # t951 = prims.linear(t950, t_transformer_h_6_attn_proj_weight, None) # t951: "cuda:0 bf16[1, 512, 4096]" - t952 = prims.convert_element_type(t951, dtypes.float32) # t952: "cuda:0 f32[1, 512, 4096]" - t953 = prims.convert_element_type(t849, dtypes.float32) # t953: "cuda:0 f32[1, 512, 4096]" - t954 = ltorch.add(t952, t953, alpha=None) # t954: "cuda:0 f32[1, 512, 4096]" - # t954 = prims.add(t952, t953) # t954: "cuda:0 f32[1, 512, 4096]" - t955 = prims.convert_element_type(t954, dtypes.bfloat16) # t955: "cuda:0 bf16[1, 512, 4096]" - t956 = prims.convert_element_type(t955, dtypes.float32) # t956: "cuda:0 f32[1, 512, 4096]" - t957 = ltorch.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - # t957 = prims.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]" - t959 = prims.sum(t957, (2,)) # t959: "cuda:0 f32[1, 512]" - t960 = prims.broadcast_in_dim(t959, [1, 512, 1], [0, 1]) # t960: "cuda:0 f32[1, 512, 1]" - t962 = ltorch.true_divide(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - # t962 = prims.div(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]" - t964 = ltorch.add(t962, 1e-05, alpha=None) # t964: "cuda:0 f32[1, 512, 1]" - # t964 = prims.add(t962, 1e-05) # t964: "cuda:0 f32[1, 512, 1]" - t965 = prims.rsqrt(t964) # t965: "cuda:0 f32[1, 512, 1]" - t966 = prims.broadcast_in_dim(t965, (1, 512, 4096), (0, 1, 2)) # t966: "cuda:0 f32[1, 512, 4096]" - t967 = ltorch.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - # t967 = prims.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]" - t968 = prims.convert_element_type(t967, dtypes.bfloat16) # t968: "cuda:0 bf16[1, 512, 4096]" - t969 = prims.broadcast_in_dim(t_transformer_h_6_norm_2_weight, (1, 512, 4096), (2,)) # t969: "cuda:0 bf16[1, 512, 4096]" - t970 = prims.convert_element_type(t968, dtypes.float32) # t970: "cuda:0 f32[1, 512, 4096]" - t971 = prims.convert_element_type(t969, dtypes.float32) # t971: "cuda:0 f32[1, 512, 4096]" - t972 = ltorch.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - # t972 = prims.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]" - t973 = prims.convert_element_type(t972, dtypes.bfloat16) # t973: "cuda:0 bf16[1, 512, 4096]" - t974 = torch.nn.functional.linear(t973, t_transformer_h_6_mlp_fc_1_weight, None) # t974: "cuda:0 bf16[1, 512, 11008]" - # t974 = ltorch.linear(t973, t_transformer_h_6_mlp_fc_1_weight, None) # t974: "cuda:0 bf16[1, 512, 11008]" - # t974 = prims.linear(t973, t_transformer_h_6_mlp_fc_1_weight, None) # t974: "cuda:0 bf16[1, 512, 11008]" - t975 = torch.nn.functional.linear(t973, t_transformer_h_6_mlp_fc_2_weight, None) # t975: "cuda:0 bf16[1, 512, 11008]" - # t975 = ltorch.linear(t973, t_transformer_h_6_mlp_fc_2_weight, None) # t975: "cuda:0 bf16[1, 512, 11008]" - # t975 = prims.linear(t973, t_transformer_h_6_mlp_fc_2_weight, None) # t975: "cuda:0 bf16[1, 512, 11008]" - t976 = prims.convert_element_type(t974, dtypes.float32) # t976: "cuda:0 f32[1, 512, 11008]" - t977 = prims.neg(t976) # t977: "cuda:0 f32[1, 512, 11008]" - t978 = prims.exp(t977) # t978: "cuda:0 f32[1, 512, 11008]" - t979 = ltorch.add(1.0, t978, alpha=None) # t979: "cuda:0 f32[1, 512, 11008]" - # t979 = prims.add(1.0, t978) # t979: "cuda:0 f32[1, 512, 11008]" - t980 = prims.reciprocal(t979) # t980: "cuda:0 f32[1, 512, 11008]" - t981 = prims.convert_element_type(t980, dtypes.bfloat16) # t981: "cuda:0 bf16[1, 512, 11008]" - t982 = prims.convert_element_type(t974, dtypes.float32) # t982: "cuda:0 f32[1, 512, 11008]" - t983 = prims.convert_element_type(t981, dtypes.float32) # t983: "cuda:0 f32[1, 512, 11008]" - t984 = ltorch.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - # t984 = prims.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]" - t985 = prims.convert_element_type(t984, dtypes.bfloat16) # t985: "cuda:0 bf16[1, 512, 11008]" - t986 = prims.convert_element_type(t985, dtypes.float32) # t986: "cuda:0 f32[1, 512, 11008]" - t987 = prims.convert_element_type(t975, dtypes.float32) # t987: "cuda:0 f32[1, 512, 11008]" - t988 = ltorch.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - # t988 = prims.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]" - t989 = prims.convert_element_type(t988, dtypes.bfloat16) # t989: "cuda:0 bf16[1, 512, 11008]" - t990 = torch.nn.functional.linear(t989, t_transformer_h_6_mlp_proj_weight, None) # t990: "cuda:0 bf16[1, 512, 4096]" - # t990 = ltorch.linear(t989, t_transformer_h_6_mlp_proj_weight, None) # t990: "cuda:0 bf16[1, 512, 4096]" - # t990 = prims.linear(t989, t_transformer_h_6_mlp_proj_weight, None) # t990: "cuda:0 bf16[1, 512, 4096]" - t991 = prims.convert_element_type(t990, dtypes.float32) # t991: "cuda:0 f32[1, 512, 4096]" - t992 = prims.convert_element_type(t955, dtypes.float32) # t992: "cuda:0 f32[1, 512, 4096]" - t993 = ltorch.add(t991, t992, alpha=None) # t993: "cuda:0 f32[1, 512, 4096]" - # t993 = prims.add(t991, t992) # t993: "cuda:0 f32[1, 512, 4096]" - t994 = prims.convert_element_type(t993, dtypes.bfloat16) # t994: "cuda:0 bf16[1, 512, 4096]" - t995 = prims.convert_element_type(t994, dtypes.float32) # t995: "cuda:0 f32[1, 512, 4096]" - t996 = ltorch.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - # t996 = prims.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]" - t998 = prims.sum(t996, (2,)) # t998: "cuda:0 f32[1, 512]" - t999 = prims.broadcast_in_dim(t998, [1, 512, 1], [0, 1]) # t999: "cuda:0 f32[1, 512, 1]" - t1001 = ltorch.true_divide(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - # t1001 = prims.div(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]" - t1003 = ltorch.add(t1001, 1e-05, alpha=None) # t1003: "cuda:0 f32[1, 512, 1]" - # t1003 = prims.add(t1001, 1e-05) # t1003: "cuda:0 f32[1, 512, 1]" - t1004 = prims.rsqrt(t1003) # t1004: "cuda:0 f32[1, 512, 1]" - t1005 = prims.broadcast_in_dim(t1004, (1, 512, 4096), (0, 1, 2)) # t1005: "cuda:0 f32[1, 512, 4096]" - t1006 = ltorch.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - # t1006 = prims.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]" - t1007 = prims.convert_element_type(t1006, dtypes.bfloat16) # t1007: "cuda:0 bf16[1, 512, 4096]" - t1008 = prims.broadcast_in_dim(t_transformer_h_7_norm_1_weight, (1, 512, 4096), (2,)) # t1008: "cuda:0 bf16[1, 512, 4096]" - t1009 = prims.convert_element_type(t1007, dtypes.float32) # t1009: "cuda:0 f32[1, 512, 4096]" - t1010 = prims.convert_element_type(t1008, dtypes.float32) # t1010: "cuda:0 f32[1, 512, 4096]" - t1011 = ltorch.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - # t1011 = prims.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]" - t1012 = prims.convert_element_type(t1011, dtypes.bfloat16) # t1012: "cuda:0 bf16[1, 512, 4096]" - t1013 = torch.nn.functional.linear(t1012, t_transformer_h_7_attn_attn_weight, None) # t1013: "cuda:0 bf16[1, 512, 12288]" - # t1013 = ltorch.linear(t1012, t_transformer_h_7_attn_attn_weight, None) # t1013: "cuda:0 bf16[1, 512, 12288]" - # t1013 = prims.linear(t1012, t_transformer_h_7_attn_attn_weight, None) # t1013: "cuda:0 bf16[1, 512, 12288]" - t1019 = prims.reshape(t1013, (1, 512, 32, 3, 128)) # t1019: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1025 = prims.transpose(t1019, (0, 2, 3, 1, 4)) # t1025: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1026, t1027, t1028) = ltorch.split(t1025, (1, 1, 1), 2) - # t1026 = prims.slice_prim(t1025, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1026: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1027 = prims.slice_prim(t1025, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1027: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1028 = prims.slice_prim(t1025, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1028: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1034 = prims.reshape(t1026, (1, 32, 512, 128)) # t1034: "cuda:0 bf16[1, 32, 512, 128]" - t1040 = prims.reshape(t1027, (1, 32, 512, 128)) # t1040: "cuda:0 bf16[1, 32, 512, 128]" - t1046 = prims.reshape(t1028, (1, 32, 512, 128)) # t1046: "cuda:0 bf16[1, 32, 512, 128]" - t1047 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1047: "cuda:0 bf16[1, 32, 512, 128]" - t1048 = prims.slice_prim(t1047, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1048: "cuda:0 bf16[1, 32, 512, 64]" - t1049 = prims.slice_prim(t1047, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1049: "cuda:0 bf16[1, 32, 512, 64]" - t1050 = prims.convert_element_type(t1049, dtypes.float32) # t1050: "cuda:0 f32[1, 32, 512, 64]" - t1051 = prims.neg(t1050) # t1051: "cuda:0 f32[1, 32, 512, 64]" - t1052 = prims.convert_element_type(t1051, dtypes.bfloat16) # t1052: "cuda:0 bf16[1, 32, 512, 64]" - t1054 = prims.cat((t1052, t1048), -1) # t1054: "cuda:0 bf16[1, 32, 512, 128]" - t1055 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1055: "cuda:0 f32[1, 32, 512, 128]" - t1056 = prims.convert_element_type(t1047, dtypes.float32) # t1056: "cuda:0 f32[1, 32, 512, 128]" - t1057 = ltorch.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - # t1057 = prims.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]" - t1058 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1058: "cuda:0 f32[1, 32, 512, 128]" - t1059 = prims.convert_element_type(t1054, dtypes.float32) # t1059: "cuda:0 f32[1, 32, 512, 128]" - t1060 = ltorch.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - # t1060 = prims.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]" - t1061 = ltorch.add(t1057, t1060, alpha=None) # t1061: "cuda:0 f32[1, 32, 512, 128]" - # t1061 = prims.add(t1057, t1060) # t1061: "cuda:0 f32[1, 32, 512, 128]" - t1062 = prims.convert_element_type(t1061, dtypes.bfloat16) # t1062: "cuda:0 bf16[1, 32, 512, 128]" - t1063 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1063: "cuda:0 bf16[1, 32, 512, 128]" - t1064 = prims.slice_prim(t1063, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1064: "cuda:0 bf16[1, 32, 512, 64]" - t1065 = prims.slice_prim(t1063, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1065: "cuda:0 bf16[1, 32, 512, 64]" - t1066 = prims.convert_element_type(t1065, dtypes.float32) # t1066: "cuda:0 f32[1, 32, 512, 64]" - t1067 = prims.neg(t1066) # t1067: "cuda:0 f32[1, 32, 512, 64]" - t1068 = prims.convert_element_type(t1067, dtypes.bfloat16) # t1068: "cuda:0 bf16[1, 32, 512, 64]" - t1070 = prims.cat((t1068, t1064), -1) # t1070: "cuda:0 bf16[1, 32, 512, 128]" - t1071 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1071: "cuda:0 f32[1, 32, 512, 128]" - t1072 = prims.convert_element_type(t1063, dtypes.float32) # t1072: "cuda:0 f32[1, 32, 512, 128]" - t1073 = ltorch.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - # t1073 = prims.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]" - t1074 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1074: "cuda:0 f32[1, 32, 512, 128]" - t1075 = prims.convert_element_type(t1070, dtypes.float32) # t1075: "cuda:0 f32[1, 32, 512, 128]" - t1076 = ltorch.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - # t1076 = prims.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]" - t1077 = ltorch.add(t1073, t1076, alpha=None) # t1077: "cuda:0 f32[1, 32, 512, 128]" - # t1077 = prims.add(t1073, t1076) # t1077: "cuda:0 f32[1, 32, 512, 128]" - t1078 = prims.convert_element_type(t1077, dtypes.bfloat16) # t1078: "cuda:0 bf16[1, 32, 512, 128]" - t1079 = prims.slice_prim(t1034, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1079: "cuda:0 bf16[1, 32, 512, 0]" - t1081 = prims.cat((t1062, t1079), -1) # t1081: "cuda:0 bf16[1, 32, 512, 128]" - t1082 = prims.slice_prim(t1040, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1082: "cuda:0 bf16[1, 32, 512, 0]" - t1084 = prims.cat((t1078, t1082), -1) # t1084: "cuda:0 bf16[1, 32, 512, 128]" - (t1085, t1086, t1087, t1088) = cudnn_sdpa_fwd(t1081, t1084, t1046, None, 0.0, True, scale=0.08838834764831843) - t1091 = prims.transpose(t1085, (0, 2, 1, 3)) # t1091: "cuda:0 bf16[1, 512, 32, 128]" - t1095 = prims.reshape(t1091, (1, 512, 4096)) # t1095: "cuda:0 bf16[1, 512, 4096]" - t1096 = torch.nn.functional.linear(t1095, t_transformer_h_7_attn_proj_weight, None) # t1096: "cuda:0 bf16[1, 512, 4096]" - # t1096 = ltorch.linear(t1095, t_transformer_h_7_attn_proj_weight, None) # t1096: "cuda:0 bf16[1, 512, 4096]" - # t1096 = prims.linear(t1095, t_transformer_h_7_attn_proj_weight, None) # t1096: "cuda:0 bf16[1, 512, 4096]" - t1097 = prims.convert_element_type(t1096, dtypes.float32) # t1097: "cuda:0 f32[1, 512, 4096]" - t1098 = prims.convert_element_type(t994, dtypes.float32) # t1098: "cuda:0 f32[1, 512, 4096]" - t1099 = ltorch.add(t1097, t1098, alpha=None) # t1099: "cuda:0 f32[1, 512, 4096]" - # t1099 = prims.add(t1097, t1098) # t1099: "cuda:0 f32[1, 512, 4096]" - t1100 = prims.convert_element_type(t1099, dtypes.bfloat16) # t1100: "cuda:0 bf16[1, 512, 4096]" - t1101 = prims.convert_element_type(t1100, dtypes.float32) # t1101: "cuda:0 f32[1, 512, 4096]" - t1102 = ltorch.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - # t1102 = prims.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]" - t1104 = prims.sum(t1102, (2,)) # t1104: "cuda:0 f32[1, 512]" - t1105 = prims.broadcast_in_dim(t1104, [1, 512, 1], [0, 1]) # t1105: "cuda:0 f32[1, 512, 1]" - t1107 = ltorch.true_divide(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - # t1107 = prims.div(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]" - t1109 = ltorch.add(t1107, 1e-05, alpha=None) # t1109: "cuda:0 f32[1, 512, 1]" - # t1109 = prims.add(t1107, 1e-05) # t1109: "cuda:0 f32[1, 512, 1]" - t1110 = prims.rsqrt(t1109) # t1110: "cuda:0 f32[1, 512, 1]" - t1111 = prims.broadcast_in_dim(t1110, (1, 512, 4096), (0, 1, 2)) # t1111: "cuda:0 f32[1, 512, 4096]" - t1112 = ltorch.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - # t1112 = prims.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]" - t1113 = prims.convert_element_type(t1112, dtypes.bfloat16) # t1113: "cuda:0 bf16[1, 512, 4096]" - t1114 = prims.broadcast_in_dim(t_transformer_h_7_norm_2_weight, (1, 512, 4096), (2,)) # t1114: "cuda:0 bf16[1, 512, 4096]" - t1115 = prims.convert_element_type(t1113, dtypes.float32) # t1115: "cuda:0 f32[1, 512, 4096]" - t1116 = prims.convert_element_type(t1114, dtypes.float32) # t1116: "cuda:0 f32[1, 512, 4096]" - t1117 = ltorch.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - # t1117 = prims.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]" - t1118 = prims.convert_element_type(t1117, dtypes.bfloat16) # t1118: "cuda:0 bf16[1, 512, 4096]" - t1119 = torch.nn.functional.linear(t1118, t_transformer_h_7_mlp_fc_1_weight, None) # t1119: "cuda:0 bf16[1, 512, 11008]" - # t1119 = ltorch.linear(t1118, t_transformer_h_7_mlp_fc_1_weight, None) # t1119: "cuda:0 bf16[1, 512, 11008]" - # t1119 = prims.linear(t1118, t_transformer_h_7_mlp_fc_1_weight, None) # t1119: "cuda:0 bf16[1, 512, 11008]" - t1120 = torch.nn.functional.linear(t1118, t_transformer_h_7_mlp_fc_2_weight, None) # t1120: "cuda:0 bf16[1, 512, 11008]" - # t1120 = ltorch.linear(t1118, t_transformer_h_7_mlp_fc_2_weight, None) # t1120: "cuda:0 bf16[1, 512, 11008]" - # t1120 = prims.linear(t1118, t_transformer_h_7_mlp_fc_2_weight, None) # t1120: "cuda:0 bf16[1, 512, 11008]" - t1121 = prims.convert_element_type(t1119, dtypes.float32) # t1121: "cuda:0 f32[1, 512, 11008]" - t1122 = prims.neg(t1121) # t1122: "cuda:0 f32[1, 512, 11008]" - t1123 = prims.exp(t1122) # t1123: "cuda:0 f32[1, 512, 11008]" - t1124 = ltorch.add(1.0, t1123, alpha=None) # t1124: "cuda:0 f32[1, 512, 11008]" - # t1124 = prims.add(1.0, t1123) # t1124: "cuda:0 f32[1, 512, 11008]" - t1125 = prims.reciprocal(t1124) # t1125: "cuda:0 f32[1, 512, 11008]" - t1126 = prims.convert_element_type(t1125, dtypes.bfloat16) # t1126: "cuda:0 bf16[1, 512, 11008]" - t1127 = prims.convert_element_type(t1119, dtypes.float32) # t1127: "cuda:0 f32[1, 512, 11008]" - t1128 = prims.convert_element_type(t1126, dtypes.float32) # t1128: "cuda:0 f32[1, 512, 11008]" - t1129 = ltorch.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - # t1129 = prims.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]" - t1130 = prims.convert_element_type(t1129, dtypes.bfloat16) # t1130: "cuda:0 bf16[1, 512, 11008]" - t1131 = prims.convert_element_type(t1130, dtypes.float32) # t1131: "cuda:0 f32[1, 512, 11008]" - t1132 = prims.convert_element_type(t1120, dtypes.float32) # t1132: "cuda:0 f32[1, 512, 11008]" - t1133 = ltorch.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - # t1133 = prims.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]" - t1134 = prims.convert_element_type(t1133, dtypes.bfloat16) # t1134: "cuda:0 bf16[1, 512, 11008]" - t1135 = torch.nn.functional.linear(t1134, t_transformer_h_7_mlp_proj_weight, None) # t1135: "cuda:0 bf16[1, 512, 4096]" - # t1135 = ltorch.linear(t1134, t_transformer_h_7_mlp_proj_weight, None) # t1135: "cuda:0 bf16[1, 512, 4096]" - # t1135 = prims.linear(t1134, t_transformer_h_7_mlp_proj_weight, None) # t1135: "cuda:0 bf16[1, 512, 4096]" - t1136 = prims.convert_element_type(t1135, dtypes.float32) # t1136: "cuda:0 f32[1, 512, 4096]" - t1137 = prims.convert_element_type(t1100, dtypes.float32) # t1137: "cuda:0 f32[1, 512, 4096]" - t1138 = ltorch.add(t1136, t1137, alpha=None) # t1138: "cuda:0 f32[1, 512, 4096]" - # t1138 = prims.add(t1136, t1137) # t1138: "cuda:0 f32[1, 512, 4096]" - t1139 = prims.convert_element_type(t1138, dtypes.bfloat16) # t1139: "cuda:0 bf16[1, 512, 4096]" - t1140 = prims.convert_element_type(t1139, dtypes.float32) # t1140: "cuda:0 f32[1, 512, 4096]" - t1141 = ltorch.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - # t1141 = prims.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]" - t1143 = prims.sum(t1141, (2,)) # t1143: "cuda:0 f32[1, 512]" - t1144 = prims.broadcast_in_dim(t1143, [1, 512, 1], [0, 1]) # t1144: "cuda:0 f32[1, 512, 1]" - t1146 = ltorch.true_divide(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - # t1146 = prims.div(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]" - t1148 = ltorch.add(t1146, 1e-05, alpha=None) # t1148: "cuda:0 f32[1, 512, 1]" - # t1148 = prims.add(t1146, 1e-05) # t1148: "cuda:0 f32[1, 512, 1]" - t1149 = prims.rsqrt(t1148) # t1149: "cuda:0 f32[1, 512, 1]" - t1150 = prims.broadcast_in_dim(t1149, (1, 512, 4096), (0, 1, 2)) # t1150: "cuda:0 f32[1, 512, 4096]" - t1151 = ltorch.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - # t1151 = prims.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]" - t1152 = prims.convert_element_type(t1151, dtypes.bfloat16) # t1152: "cuda:0 bf16[1, 512, 4096]" - t1153 = prims.broadcast_in_dim(t_transformer_h_8_norm_1_weight, (1, 512, 4096), (2,)) # t1153: "cuda:0 bf16[1, 512, 4096]" - t1154 = prims.convert_element_type(t1152, dtypes.float32) # t1154: "cuda:0 f32[1, 512, 4096]" - t1155 = prims.convert_element_type(t1153, dtypes.float32) # t1155: "cuda:0 f32[1, 512, 4096]" - t1156 = ltorch.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - # t1156 = prims.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]" - t1157 = prims.convert_element_type(t1156, dtypes.bfloat16) # t1157: "cuda:0 bf16[1, 512, 4096]" - t1158 = torch.nn.functional.linear(t1157, t_transformer_h_8_attn_attn_weight, None) # t1158: "cuda:0 bf16[1, 512, 12288]" - # t1158 = ltorch.linear(t1157, t_transformer_h_8_attn_attn_weight, None) # t1158: "cuda:0 bf16[1, 512, 12288]" - # t1158 = prims.linear(t1157, t_transformer_h_8_attn_attn_weight, None) # t1158: "cuda:0 bf16[1, 512, 12288]" - t1164 = prims.reshape(t1158, (1, 512, 32, 3, 128)) # t1164: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1170 = prims.transpose(t1164, (0, 2, 3, 1, 4)) # t1170: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1171, t1172, t1173) = ltorch.split(t1170, (1, 1, 1), 2) - # t1171 = prims.slice_prim(t1170, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1171: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1172 = prims.slice_prim(t1170, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1172: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1173 = prims.slice_prim(t1170, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1173: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1179 = prims.reshape(t1171, (1, 32, 512, 128)) # t1179: "cuda:0 bf16[1, 32, 512, 128]" - t1185 = prims.reshape(t1172, (1, 32, 512, 128)) # t1185: "cuda:0 bf16[1, 32, 512, 128]" - t1191 = prims.reshape(t1173, (1, 32, 512, 128)) # t1191: "cuda:0 bf16[1, 32, 512, 128]" - t1192 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1192: "cuda:0 bf16[1, 32, 512, 128]" - t1193 = prims.slice_prim(t1192, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1193: "cuda:0 bf16[1, 32, 512, 64]" - t1194 = prims.slice_prim(t1192, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1194: "cuda:0 bf16[1, 32, 512, 64]" - t1195 = prims.convert_element_type(t1194, dtypes.float32) # t1195: "cuda:0 f32[1, 32, 512, 64]" - t1196 = prims.neg(t1195) # t1196: "cuda:0 f32[1, 32, 512, 64]" - t1197 = prims.convert_element_type(t1196, dtypes.bfloat16) # t1197: "cuda:0 bf16[1, 32, 512, 64]" - t1199 = prims.cat((t1197, t1193), -1) # t1199: "cuda:0 bf16[1, 32, 512, 128]" - t1200 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1200: "cuda:0 f32[1, 32, 512, 128]" - t1201 = prims.convert_element_type(t1192, dtypes.float32) # t1201: "cuda:0 f32[1, 32, 512, 128]" - t1202 = ltorch.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - # t1202 = prims.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]" - t1203 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1203: "cuda:0 f32[1, 32, 512, 128]" - t1204 = prims.convert_element_type(t1199, dtypes.float32) # t1204: "cuda:0 f32[1, 32, 512, 128]" - t1205 = ltorch.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - # t1205 = prims.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]" - t1206 = ltorch.add(t1202, t1205, alpha=None) # t1206: "cuda:0 f32[1, 32, 512, 128]" - # t1206 = prims.add(t1202, t1205) # t1206: "cuda:0 f32[1, 32, 512, 128]" - t1207 = prims.convert_element_type(t1206, dtypes.bfloat16) # t1207: "cuda:0 bf16[1, 32, 512, 128]" - t1208 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1208: "cuda:0 bf16[1, 32, 512, 128]" - t1209 = prims.slice_prim(t1208, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1209: "cuda:0 bf16[1, 32, 512, 64]" - t1210 = prims.slice_prim(t1208, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1210: "cuda:0 bf16[1, 32, 512, 64]" - t1211 = prims.convert_element_type(t1210, dtypes.float32) # t1211: "cuda:0 f32[1, 32, 512, 64]" - t1212 = prims.neg(t1211) # t1212: "cuda:0 f32[1, 32, 512, 64]" - t1213 = prims.convert_element_type(t1212, dtypes.bfloat16) # t1213: "cuda:0 bf16[1, 32, 512, 64]" - t1215 = prims.cat((t1213, t1209), -1) # t1215: "cuda:0 bf16[1, 32, 512, 128]" - t1216 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1216: "cuda:0 f32[1, 32, 512, 128]" - t1217 = prims.convert_element_type(t1208, dtypes.float32) # t1217: "cuda:0 f32[1, 32, 512, 128]" - t1218 = ltorch.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - # t1218 = prims.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]" - t1219 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1219: "cuda:0 f32[1, 32, 512, 128]" - t1220 = prims.convert_element_type(t1215, dtypes.float32) # t1220: "cuda:0 f32[1, 32, 512, 128]" - t1221 = ltorch.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - # t1221 = prims.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]" - t1222 = ltorch.add(t1218, t1221, alpha=None) # t1222: "cuda:0 f32[1, 32, 512, 128]" - # t1222 = prims.add(t1218, t1221) # t1222: "cuda:0 f32[1, 32, 512, 128]" - t1223 = prims.convert_element_type(t1222, dtypes.bfloat16) # t1223: "cuda:0 bf16[1, 32, 512, 128]" - t1224 = prims.slice_prim(t1179, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1224: "cuda:0 bf16[1, 32, 512, 0]" - t1226 = prims.cat((t1207, t1224), -1) # t1226: "cuda:0 bf16[1, 32, 512, 128]" - t1227 = prims.slice_prim(t1185, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1227: "cuda:0 bf16[1, 32, 512, 0]" - t1229 = prims.cat((t1223, t1227), -1) # t1229: "cuda:0 bf16[1, 32, 512, 128]" - (t1230, t1231, t1232, t1233) = cudnn_sdpa_fwd(t1226, t1229, t1191, None, 0.0, True, scale=0.08838834764831843) - t1236 = prims.transpose(t1230, (0, 2, 1, 3)) # t1236: "cuda:0 bf16[1, 512, 32, 128]" - t1240 = prims.reshape(t1236, (1, 512, 4096)) # t1240: "cuda:0 bf16[1, 512, 4096]" - t1241 = torch.nn.functional.linear(t1240, t_transformer_h_8_attn_proj_weight, None) # t1241: "cuda:0 bf16[1, 512, 4096]" - # t1241 = ltorch.linear(t1240, t_transformer_h_8_attn_proj_weight, None) # t1241: "cuda:0 bf16[1, 512, 4096]" - # t1241 = prims.linear(t1240, t_transformer_h_8_attn_proj_weight, None) # t1241: "cuda:0 bf16[1, 512, 4096]" - t1242 = prims.convert_element_type(t1241, dtypes.float32) # t1242: "cuda:0 f32[1, 512, 4096]" - t1243 = prims.convert_element_type(t1139, dtypes.float32) # t1243: "cuda:0 f32[1, 512, 4096]" - t1244 = ltorch.add(t1242, t1243, alpha=None) # t1244: "cuda:0 f32[1, 512, 4096]" - # t1244 = prims.add(t1242, t1243) # t1244: "cuda:0 f32[1, 512, 4096]" - t1245 = prims.convert_element_type(t1244, dtypes.bfloat16) # t1245: "cuda:0 bf16[1, 512, 4096]" - t1246 = prims.convert_element_type(t1245, dtypes.float32) # t1246: "cuda:0 f32[1, 512, 4096]" - t1247 = ltorch.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - # t1247 = prims.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]" - t1249 = prims.sum(t1247, (2,)) # t1249: "cuda:0 f32[1, 512]" - t1250 = prims.broadcast_in_dim(t1249, [1, 512, 1], [0, 1]) # t1250: "cuda:0 f32[1, 512, 1]" - t1252 = ltorch.true_divide(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - # t1252 = prims.div(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]" - t1254 = ltorch.add(t1252, 1e-05, alpha=None) # t1254: "cuda:0 f32[1, 512, 1]" - # t1254 = prims.add(t1252, 1e-05) # t1254: "cuda:0 f32[1, 512, 1]" - t1255 = prims.rsqrt(t1254) # t1255: "cuda:0 f32[1, 512, 1]" - t1256 = prims.broadcast_in_dim(t1255, (1, 512, 4096), (0, 1, 2)) # t1256: "cuda:0 f32[1, 512, 4096]" - t1257 = ltorch.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - # t1257 = prims.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]" - t1258 = prims.convert_element_type(t1257, dtypes.bfloat16) # t1258: "cuda:0 bf16[1, 512, 4096]" - t1259 = prims.broadcast_in_dim(t_transformer_h_8_norm_2_weight, (1, 512, 4096), (2,)) # t1259: "cuda:0 bf16[1, 512, 4096]" - t1260 = prims.convert_element_type(t1258, dtypes.float32) # t1260: "cuda:0 f32[1, 512, 4096]" - t1261 = prims.convert_element_type(t1259, dtypes.float32) # t1261: "cuda:0 f32[1, 512, 4096]" - t1262 = ltorch.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - # t1262 = prims.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]" - t1263 = prims.convert_element_type(t1262, dtypes.bfloat16) # t1263: "cuda:0 bf16[1, 512, 4096]" - t1264 = torch.nn.functional.linear(t1263, t_transformer_h_8_mlp_fc_1_weight, None) # t1264: "cuda:0 bf16[1, 512, 11008]" - # t1264 = ltorch.linear(t1263, t_transformer_h_8_mlp_fc_1_weight, None) # t1264: "cuda:0 bf16[1, 512, 11008]" - # t1264 = prims.linear(t1263, t_transformer_h_8_mlp_fc_1_weight, None) # t1264: "cuda:0 bf16[1, 512, 11008]" - t1265 = torch.nn.functional.linear(t1263, t_transformer_h_8_mlp_fc_2_weight, None) # t1265: "cuda:0 bf16[1, 512, 11008]" - # t1265 = ltorch.linear(t1263, t_transformer_h_8_mlp_fc_2_weight, None) # t1265: "cuda:0 bf16[1, 512, 11008]" - # t1265 = prims.linear(t1263, t_transformer_h_8_mlp_fc_2_weight, None) # t1265: "cuda:0 bf16[1, 512, 11008]" - t1266 = prims.convert_element_type(t1264, dtypes.float32) # t1266: "cuda:0 f32[1, 512, 11008]" - t1267 = prims.neg(t1266) # t1267: "cuda:0 f32[1, 512, 11008]" - t1268 = prims.exp(t1267) # t1268: "cuda:0 f32[1, 512, 11008]" - t1269 = ltorch.add(1.0, t1268, alpha=None) # t1269: "cuda:0 f32[1, 512, 11008]" - # t1269 = prims.add(1.0, t1268) # t1269: "cuda:0 f32[1, 512, 11008]" - t1270 = prims.reciprocal(t1269) # t1270: "cuda:0 f32[1, 512, 11008]" - t1271 = prims.convert_element_type(t1270, dtypes.bfloat16) # t1271: "cuda:0 bf16[1, 512, 11008]" - t1272 = prims.convert_element_type(t1264, dtypes.float32) # t1272: "cuda:0 f32[1, 512, 11008]" - t1273 = prims.convert_element_type(t1271, dtypes.float32) # t1273: "cuda:0 f32[1, 512, 11008]" - t1274 = ltorch.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - # t1274 = prims.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]" - t1275 = prims.convert_element_type(t1274, dtypes.bfloat16) # t1275: "cuda:0 bf16[1, 512, 11008]" - t1276 = prims.convert_element_type(t1275, dtypes.float32) # t1276: "cuda:0 f32[1, 512, 11008]" - t1277 = prims.convert_element_type(t1265, dtypes.float32) # t1277: "cuda:0 f32[1, 512, 11008]" - t1278 = ltorch.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - # t1278 = prims.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]" - t1279 = prims.convert_element_type(t1278, dtypes.bfloat16) # t1279: "cuda:0 bf16[1, 512, 11008]" - t1280 = torch.nn.functional.linear(t1279, t_transformer_h_8_mlp_proj_weight, None) # t1280: "cuda:0 bf16[1, 512, 4096]" - # t1280 = ltorch.linear(t1279, t_transformer_h_8_mlp_proj_weight, None) # t1280: "cuda:0 bf16[1, 512, 4096]" - # t1280 = prims.linear(t1279, t_transformer_h_8_mlp_proj_weight, None) # t1280: "cuda:0 bf16[1, 512, 4096]" - t1281 = prims.convert_element_type(t1280, dtypes.float32) # t1281: "cuda:0 f32[1, 512, 4096]" - t1282 = prims.convert_element_type(t1245, dtypes.float32) # t1282: "cuda:0 f32[1, 512, 4096]" - t1283 = ltorch.add(t1281, t1282, alpha=None) # t1283: "cuda:0 f32[1, 512, 4096]" - # t1283 = prims.add(t1281, t1282) # t1283: "cuda:0 f32[1, 512, 4096]" - t1284 = prims.convert_element_type(t1283, dtypes.bfloat16) # t1284: "cuda:0 bf16[1, 512, 4096]" - t1285 = prims.convert_element_type(t1284, dtypes.float32) # t1285: "cuda:0 f32[1, 512, 4096]" - t1286 = ltorch.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - # t1286 = prims.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]" - t1288 = prims.sum(t1286, (2,)) # t1288: "cuda:0 f32[1, 512]" - t1289 = prims.broadcast_in_dim(t1288, [1, 512, 1], [0, 1]) # t1289: "cuda:0 f32[1, 512, 1]" - t1291 = ltorch.true_divide(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - # t1291 = prims.div(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]" - t1293 = ltorch.add(t1291, 1e-05, alpha=None) # t1293: "cuda:0 f32[1, 512, 1]" - # t1293 = prims.add(t1291, 1e-05) # t1293: "cuda:0 f32[1, 512, 1]" - t1294 = prims.rsqrt(t1293) # t1294: "cuda:0 f32[1, 512, 1]" - t1295 = prims.broadcast_in_dim(t1294, (1, 512, 4096), (0, 1, 2)) # t1295: "cuda:0 f32[1, 512, 4096]" - t1296 = ltorch.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - # t1296 = prims.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]" - t1297 = prims.convert_element_type(t1296, dtypes.bfloat16) # t1297: "cuda:0 bf16[1, 512, 4096]" - t1298 = prims.broadcast_in_dim(t_transformer_h_9_norm_1_weight, (1, 512, 4096), (2,)) # t1298: "cuda:0 bf16[1, 512, 4096]" - t1299 = prims.convert_element_type(t1297, dtypes.float32) # t1299: "cuda:0 f32[1, 512, 4096]" - t1300 = prims.convert_element_type(t1298, dtypes.float32) # t1300: "cuda:0 f32[1, 512, 4096]" - t1301 = ltorch.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - # t1301 = prims.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]" - t1302 = prims.convert_element_type(t1301, dtypes.bfloat16) # t1302: "cuda:0 bf16[1, 512, 4096]" - t1303 = torch.nn.functional.linear(t1302, t_transformer_h_9_attn_attn_weight, None) # t1303: "cuda:0 bf16[1, 512, 12288]" - # t1303 = ltorch.linear(t1302, t_transformer_h_9_attn_attn_weight, None) # t1303: "cuda:0 bf16[1, 512, 12288]" - # t1303 = prims.linear(t1302, t_transformer_h_9_attn_attn_weight, None) # t1303: "cuda:0 bf16[1, 512, 12288]" - t1309 = prims.reshape(t1303, (1, 512, 32, 3, 128)) # t1309: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1315 = prims.transpose(t1309, (0, 2, 3, 1, 4)) # t1315: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1316, t1317, t1318) = ltorch.split(t1315, (1, 1, 1), 2) - # t1316 = prims.slice_prim(t1315, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1316: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1317 = prims.slice_prim(t1315, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1317: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1318 = prims.slice_prim(t1315, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1318: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1324 = prims.reshape(t1316, (1, 32, 512, 128)) # t1324: "cuda:0 bf16[1, 32, 512, 128]" - t1330 = prims.reshape(t1317, (1, 32, 512, 128)) # t1330: "cuda:0 bf16[1, 32, 512, 128]" - t1336 = prims.reshape(t1318, (1, 32, 512, 128)) # t1336: "cuda:0 bf16[1, 32, 512, 128]" - t1337 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1337: "cuda:0 bf16[1, 32, 512, 128]" - t1338 = prims.slice_prim(t1337, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1338: "cuda:0 bf16[1, 32, 512, 64]" - t1339 = prims.slice_prim(t1337, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1339: "cuda:0 bf16[1, 32, 512, 64]" - t1340 = prims.convert_element_type(t1339, dtypes.float32) # t1340: "cuda:0 f32[1, 32, 512, 64]" - t1341 = prims.neg(t1340) # t1341: "cuda:0 f32[1, 32, 512, 64]" - t1342 = prims.convert_element_type(t1341, dtypes.bfloat16) # t1342: "cuda:0 bf16[1, 32, 512, 64]" - t1344 = prims.cat((t1342, t1338), -1) # t1344: "cuda:0 bf16[1, 32, 512, 128]" - t1345 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1345: "cuda:0 f32[1, 32, 512, 128]" - t1346 = prims.convert_element_type(t1337, dtypes.float32) # t1346: "cuda:0 f32[1, 32, 512, 128]" - t1347 = ltorch.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - # t1347 = prims.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]" - t1348 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1348: "cuda:0 f32[1, 32, 512, 128]" - t1349 = prims.convert_element_type(t1344, dtypes.float32) # t1349: "cuda:0 f32[1, 32, 512, 128]" - t1350 = ltorch.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - # t1350 = prims.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]" - t1351 = ltorch.add(t1347, t1350, alpha=None) # t1351: "cuda:0 f32[1, 32, 512, 128]" - # t1351 = prims.add(t1347, t1350) # t1351: "cuda:0 f32[1, 32, 512, 128]" - t1352 = prims.convert_element_type(t1351, dtypes.bfloat16) # t1352: "cuda:0 bf16[1, 32, 512, 128]" - t1353 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1353: "cuda:0 bf16[1, 32, 512, 128]" - t1354 = prims.slice_prim(t1353, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1354: "cuda:0 bf16[1, 32, 512, 64]" - t1355 = prims.slice_prim(t1353, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1355: "cuda:0 bf16[1, 32, 512, 64]" - t1356 = prims.convert_element_type(t1355, dtypes.float32) # t1356: "cuda:0 f32[1, 32, 512, 64]" - t1357 = prims.neg(t1356) # t1357: "cuda:0 f32[1, 32, 512, 64]" - t1358 = prims.convert_element_type(t1357, dtypes.bfloat16) # t1358: "cuda:0 bf16[1, 32, 512, 64]" - t1360 = prims.cat((t1358, t1354), -1) # t1360: "cuda:0 bf16[1, 32, 512, 128]" - t1361 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1361: "cuda:0 f32[1, 32, 512, 128]" - t1362 = prims.convert_element_type(t1353, dtypes.float32) # t1362: "cuda:0 f32[1, 32, 512, 128]" - t1363 = ltorch.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - # t1363 = prims.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]" - t1364 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1364: "cuda:0 f32[1, 32, 512, 128]" - t1365 = prims.convert_element_type(t1360, dtypes.float32) # t1365: "cuda:0 f32[1, 32, 512, 128]" - t1366 = ltorch.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - # t1366 = prims.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]" - t1367 = ltorch.add(t1363, t1366, alpha=None) # t1367: "cuda:0 f32[1, 32, 512, 128]" - # t1367 = prims.add(t1363, t1366) # t1367: "cuda:0 f32[1, 32, 512, 128]" - t1368 = prims.convert_element_type(t1367, dtypes.bfloat16) # t1368: "cuda:0 bf16[1, 32, 512, 128]" - t1369 = prims.slice_prim(t1324, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1369: "cuda:0 bf16[1, 32, 512, 0]" - t1371 = prims.cat((t1352, t1369), -1) # t1371: "cuda:0 bf16[1, 32, 512, 128]" - t1372 = prims.slice_prim(t1330, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1372: "cuda:0 bf16[1, 32, 512, 0]" - t1374 = prims.cat((t1368, t1372), -1) # t1374: "cuda:0 bf16[1, 32, 512, 128]" - (t1375, t1376, t1377, t1378) = cudnn_sdpa_fwd(t1371, t1374, t1336, None, 0.0, True, scale=0.08838834764831843) - t1381 = prims.transpose(t1375, (0, 2, 1, 3)) # t1381: "cuda:0 bf16[1, 512, 32, 128]" - t1385 = prims.reshape(t1381, (1, 512, 4096)) # t1385: "cuda:0 bf16[1, 512, 4096]" - t1386 = torch.nn.functional.linear(t1385, t_transformer_h_9_attn_proj_weight, None) # t1386: "cuda:0 bf16[1, 512, 4096]" - # t1386 = ltorch.linear(t1385, t_transformer_h_9_attn_proj_weight, None) # t1386: "cuda:0 bf16[1, 512, 4096]" - # t1386 = prims.linear(t1385, t_transformer_h_9_attn_proj_weight, None) # t1386: "cuda:0 bf16[1, 512, 4096]" - t1387 = prims.convert_element_type(t1386, dtypes.float32) # t1387: "cuda:0 f32[1, 512, 4096]" - t1388 = prims.convert_element_type(t1284, dtypes.float32) # t1388: "cuda:0 f32[1, 512, 4096]" - t1389 = ltorch.add(t1387, t1388, alpha=None) # t1389: "cuda:0 f32[1, 512, 4096]" - # t1389 = prims.add(t1387, t1388) # t1389: "cuda:0 f32[1, 512, 4096]" - t1390 = prims.convert_element_type(t1389, dtypes.bfloat16) # t1390: "cuda:0 bf16[1, 512, 4096]" - t1391 = prims.convert_element_type(t1390, dtypes.float32) # t1391: "cuda:0 f32[1, 512, 4096]" - t1392 = ltorch.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - # t1392 = prims.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]" - t1394 = prims.sum(t1392, (2,)) # t1394: "cuda:0 f32[1, 512]" - t1395 = prims.broadcast_in_dim(t1394, [1, 512, 1], [0, 1]) # t1395: "cuda:0 f32[1, 512, 1]" - t1397 = ltorch.true_divide(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - # t1397 = prims.div(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]" - t1399 = ltorch.add(t1397, 1e-05, alpha=None) # t1399: "cuda:0 f32[1, 512, 1]" - # t1399 = prims.add(t1397, 1e-05) # t1399: "cuda:0 f32[1, 512, 1]" - t1400 = prims.rsqrt(t1399) # t1400: "cuda:0 f32[1, 512, 1]" - t1401 = prims.broadcast_in_dim(t1400, (1, 512, 4096), (0, 1, 2)) # t1401: "cuda:0 f32[1, 512, 4096]" - t1402 = ltorch.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - # t1402 = prims.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]" - t1403 = prims.convert_element_type(t1402, dtypes.bfloat16) # t1403: "cuda:0 bf16[1, 512, 4096]" - t1404 = prims.broadcast_in_dim(t_transformer_h_9_norm_2_weight, (1, 512, 4096), (2,)) # t1404: "cuda:0 bf16[1, 512, 4096]" - t1405 = prims.convert_element_type(t1403, dtypes.float32) # t1405: "cuda:0 f32[1, 512, 4096]" - t1406 = prims.convert_element_type(t1404, dtypes.float32) # t1406: "cuda:0 f32[1, 512, 4096]" - t1407 = ltorch.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - # t1407 = prims.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]" - t1408 = prims.convert_element_type(t1407, dtypes.bfloat16) # t1408: "cuda:0 bf16[1, 512, 4096]" - t1409 = torch.nn.functional.linear(t1408, t_transformer_h_9_mlp_fc_1_weight, None) # t1409: "cuda:0 bf16[1, 512, 11008]" - # t1409 = ltorch.linear(t1408, t_transformer_h_9_mlp_fc_1_weight, None) # t1409: "cuda:0 bf16[1, 512, 11008]" - # t1409 = prims.linear(t1408, t_transformer_h_9_mlp_fc_1_weight, None) # t1409: "cuda:0 bf16[1, 512, 11008]" - t1410 = torch.nn.functional.linear(t1408, t_transformer_h_9_mlp_fc_2_weight, None) # t1410: "cuda:0 bf16[1, 512, 11008]" - # t1410 = ltorch.linear(t1408, t_transformer_h_9_mlp_fc_2_weight, None) # t1410: "cuda:0 bf16[1, 512, 11008]" - # t1410 = prims.linear(t1408, t_transformer_h_9_mlp_fc_2_weight, None) # t1410: "cuda:0 bf16[1, 512, 11008]" - t1411 = prims.convert_element_type(t1409, dtypes.float32) # t1411: "cuda:0 f32[1, 512, 11008]" - t1412 = prims.neg(t1411) # t1412: "cuda:0 f32[1, 512, 11008]" - t1413 = prims.exp(t1412) # t1413: "cuda:0 f32[1, 512, 11008]" - t1414 = ltorch.add(1.0, t1413, alpha=None) # t1414: "cuda:0 f32[1, 512, 11008]" - # t1414 = prims.add(1.0, t1413) # t1414: "cuda:0 f32[1, 512, 11008]" - t1415 = prims.reciprocal(t1414) # t1415: "cuda:0 f32[1, 512, 11008]" - t1416 = prims.convert_element_type(t1415, dtypes.bfloat16) # t1416: "cuda:0 bf16[1, 512, 11008]" - t1417 = prims.convert_element_type(t1409, dtypes.float32) # t1417: "cuda:0 f32[1, 512, 11008]" - t1418 = prims.convert_element_type(t1416, dtypes.float32) # t1418: "cuda:0 f32[1, 512, 11008]" - t1419 = ltorch.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - # t1419 = prims.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]" - t1420 = prims.convert_element_type(t1419, dtypes.bfloat16) # t1420: "cuda:0 bf16[1, 512, 11008]" - t1421 = prims.convert_element_type(t1420, dtypes.float32) # t1421: "cuda:0 f32[1, 512, 11008]" - t1422 = prims.convert_element_type(t1410, dtypes.float32) # t1422: "cuda:0 f32[1, 512, 11008]" - t1423 = ltorch.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - # t1423 = prims.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]" - t1424 = prims.convert_element_type(t1423, dtypes.bfloat16) # t1424: "cuda:0 bf16[1, 512, 11008]" - t1425 = torch.nn.functional.linear(t1424, t_transformer_h_9_mlp_proj_weight, None) # t1425: "cuda:0 bf16[1, 512, 4096]" - # t1425 = ltorch.linear(t1424, t_transformer_h_9_mlp_proj_weight, None) # t1425: "cuda:0 bf16[1, 512, 4096]" - # t1425 = prims.linear(t1424, t_transformer_h_9_mlp_proj_weight, None) # t1425: "cuda:0 bf16[1, 512, 4096]" - t1426 = prims.convert_element_type(t1425, dtypes.float32) # t1426: "cuda:0 f32[1, 512, 4096]" - t1427 = prims.convert_element_type(t1390, dtypes.float32) # t1427: "cuda:0 f32[1, 512, 4096]" - t1428 = ltorch.add(t1426, t1427, alpha=None) # t1428: "cuda:0 f32[1, 512, 4096]" - # t1428 = prims.add(t1426, t1427) # t1428: "cuda:0 f32[1, 512, 4096]" - t1429 = prims.convert_element_type(t1428, dtypes.bfloat16) # t1429: "cuda:0 bf16[1, 512, 4096]" - t1430 = prims.convert_element_type(t1429, dtypes.float32) # t1430: "cuda:0 f32[1, 512, 4096]" - t1431 = ltorch.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - # t1431 = prims.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]" - t1433 = prims.sum(t1431, (2,)) # t1433: "cuda:0 f32[1, 512]" - t1434 = prims.broadcast_in_dim(t1433, [1, 512, 1], [0, 1]) # t1434: "cuda:0 f32[1, 512, 1]" - t1436 = ltorch.true_divide(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - # t1436 = prims.div(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]" - t1438 = ltorch.add(t1436, 1e-05, alpha=None) # t1438: "cuda:0 f32[1, 512, 1]" - # t1438 = prims.add(t1436, 1e-05) # t1438: "cuda:0 f32[1, 512, 1]" - t1439 = prims.rsqrt(t1438) # t1439: "cuda:0 f32[1, 512, 1]" - t1440 = prims.broadcast_in_dim(t1439, (1, 512, 4096), (0, 1, 2)) # t1440: "cuda:0 f32[1, 512, 4096]" - t1441 = ltorch.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - # t1441 = prims.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]" - t1442 = prims.convert_element_type(t1441, dtypes.bfloat16) # t1442: "cuda:0 bf16[1, 512, 4096]" - t1443 = prims.broadcast_in_dim(t_transformer_h_10_norm_1_weight, (1, 512, 4096), (2,)) # t1443: "cuda:0 bf16[1, 512, 4096]" - t1444 = prims.convert_element_type(t1442, dtypes.float32) # t1444: "cuda:0 f32[1, 512, 4096]" - t1445 = prims.convert_element_type(t1443, dtypes.float32) # t1445: "cuda:0 f32[1, 512, 4096]" - t1446 = ltorch.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - # t1446 = prims.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]" - t1447 = prims.convert_element_type(t1446, dtypes.bfloat16) # t1447: "cuda:0 bf16[1, 512, 4096]" - t1448 = torch.nn.functional.linear(t1447, t_transformer_h_10_attn_attn_weight, None) # t1448: "cuda:0 bf16[1, 512, 12288]" - # t1448 = ltorch.linear(t1447, t_transformer_h_10_attn_attn_weight, None) # t1448: "cuda:0 bf16[1, 512, 12288]" - # t1448 = prims.linear(t1447, t_transformer_h_10_attn_attn_weight, None) # t1448: "cuda:0 bf16[1, 512, 12288]" - t1454 = prims.reshape(t1448, (1, 512, 32, 3, 128)) # t1454: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1460 = prims.transpose(t1454, (0, 2, 3, 1, 4)) # t1460: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1461, t1462, t1463) = ltorch.split(t1460, (1, 1, 1), 2) - # t1461 = prims.slice_prim(t1460, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1461: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1462 = prims.slice_prim(t1460, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1462: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1463 = prims.slice_prim(t1460, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1463: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1469 = prims.reshape(t1461, (1, 32, 512, 128)) # t1469: "cuda:0 bf16[1, 32, 512, 128]" - t1475 = prims.reshape(t1462, (1, 32, 512, 128)) # t1475: "cuda:0 bf16[1, 32, 512, 128]" - t1481 = prims.reshape(t1463, (1, 32, 512, 128)) # t1481: "cuda:0 bf16[1, 32, 512, 128]" - t1482 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1482: "cuda:0 bf16[1, 32, 512, 128]" - t1483 = prims.slice_prim(t1482, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1483: "cuda:0 bf16[1, 32, 512, 64]" - t1484 = prims.slice_prim(t1482, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1484: "cuda:0 bf16[1, 32, 512, 64]" - t1485 = prims.convert_element_type(t1484, dtypes.float32) # t1485: "cuda:0 f32[1, 32, 512, 64]" - t1486 = prims.neg(t1485) # t1486: "cuda:0 f32[1, 32, 512, 64]" - t1487 = prims.convert_element_type(t1486, dtypes.bfloat16) # t1487: "cuda:0 bf16[1, 32, 512, 64]" - t1489 = prims.cat((t1487, t1483), -1) # t1489: "cuda:0 bf16[1, 32, 512, 128]" - t1490 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1490: "cuda:0 f32[1, 32, 512, 128]" - t1491 = prims.convert_element_type(t1482, dtypes.float32) # t1491: "cuda:0 f32[1, 32, 512, 128]" - t1492 = ltorch.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - # t1492 = prims.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]" - t1493 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1493: "cuda:0 f32[1, 32, 512, 128]" - t1494 = prims.convert_element_type(t1489, dtypes.float32) # t1494: "cuda:0 f32[1, 32, 512, 128]" - t1495 = ltorch.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - # t1495 = prims.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]" - t1496 = ltorch.add(t1492, t1495, alpha=None) # t1496: "cuda:0 f32[1, 32, 512, 128]" - # t1496 = prims.add(t1492, t1495) # t1496: "cuda:0 f32[1, 32, 512, 128]" - t1497 = prims.convert_element_type(t1496, dtypes.bfloat16) # t1497: "cuda:0 bf16[1, 32, 512, 128]" - t1498 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1498: "cuda:0 bf16[1, 32, 512, 128]" - t1499 = prims.slice_prim(t1498, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1499: "cuda:0 bf16[1, 32, 512, 64]" - t1500 = prims.slice_prim(t1498, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1500: "cuda:0 bf16[1, 32, 512, 64]" - t1501 = prims.convert_element_type(t1500, dtypes.float32) # t1501: "cuda:0 f32[1, 32, 512, 64]" - t1502 = prims.neg(t1501) # t1502: "cuda:0 f32[1, 32, 512, 64]" - t1503 = prims.convert_element_type(t1502, dtypes.bfloat16) # t1503: "cuda:0 bf16[1, 32, 512, 64]" - t1505 = prims.cat((t1503, t1499), -1) # t1505: "cuda:0 bf16[1, 32, 512, 128]" - t1506 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1506: "cuda:0 f32[1, 32, 512, 128]" - t1507 = prims.convert_element_type(t1498, dtypes.float32) # t1507: "cuda:0 f32[1, 32, 512, 128]" - t1508 = ltorch.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - # t1508 = prims.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]" - t1509 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1509: "cuda:0 f32[1, 32, 512, 128]" - t1510 = prims.convert_element_type(t1505, dtypes.float32) # t1510: "cuda:0 f32[1, 32, 512, 128]" - t1511 = ltorch.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - # t1511 = prims.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]" - t1512 = ltorch.add(t1508, t1511, alpha=None) # t1512: "cuda:0 f32[1, 32, 512, 128]" - # t1512 = prims.add(t1508, t1511) # t1512: "cuda:0 f32[1, 32, 512, 128]" - t1513 = prims.convert_element_type(t1512, dtypes.bfloat16) # t1513: "cuda:0 bf16[1, 32, 512, 128]" - t1514 = prims.slice_prim(t1469, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1514: "cuda:0 bf16[1, 32, 512, 0]" - t1516 = prims.cat((t1497, t1514), -1) # t1516: "cuda:0 bf16[1, 32, 512, 128]" - t1517 = prims.slice_prim(t1475, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1517: "cuda:0 bf16[1, 32, 512, 0]" - t1519 = prims.cat((t1513, t1517), -1) # t1519: "cuda:0 bf16[1, 32, 512, 128]" - (t1520, t1521, t1522, t1523) = cudnn_sdpa_fwd(t1516, t1519, t1481, None, 0.0, True, scale=0.08838834764831843) - t1526 = prims.transpose(t1520, (0, 2, 1, 3)) # t1526: "cuda:0 bf16[1, 512, 32, 128]" - t1530 = prims.reshape(t1526, (1, 512, 4096)) # t1530: "cuda:0 bf16[1, 512, 4096]" - t1531 = torch.nn.functional.linear(t1530, t_transformer_h_10_attn_proj_weight, None) # t1531: "cuda:0 bf16[1, 512, 4096]" - # t1531 = ltorch.linear(t1530, t_transformer_h_10_attn_proj_weight, None) # t1531: "cuda:0 bf16[1, 512, 4096]" - # t1531 = prims.linear(t1530, t_transformer_h_10_attn_proj_weight, None) # t1531: "cuda:0 bf16[1, 512, 4096]" - t1532 = prims.convert_element_type(t1531, dtypes.float32) # t1532: "cuda:0 f32[1, 512, 4096]" - t1533 = prims.convert_element_type(t1429, dtypes.float32) # t1533: "cuda:0 f32[1, 512, 4096]" - t1534 = ltorch.add(t1532, t1533, alpha=None) # t1534: "cuda:0 f32[1, 512, 4096]" - # t1534 = prims.add(t1532, t1533) # t1534: "cuda:0 f32[1, 512, 4096]" - t1535 = prims.convert_element_type(t1534, dtypes.bfloat16) # t1535: "cuda:0 bf16[1, 512, 4096]" - t1536 = prims.convert_element_type(t1535, dtypes.float32) # t1536: "cuda:0 f32[1, 512, 4096]" - t1537 = ltorch.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - # t1537 = prims.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]" - t1539 = prims.sum(t1537, (2,)) # t1539: "cuda:0 f32[1, 512]" - t1540 = prims.broadcast_in_dim(t1539, [1, 512, 1], [0, 1]) # t1540: "cuda:0 f32[1, 512, 1]" - t1542 = ltorch.true_divide(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - # t1542 = prims.div(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]" - t1544 = ltorch.add(t1542, 1e-05, alpha=None) # t1544: "cuda:0 f32[1, 512, 1]" - # t1544 = prims.add(t1542, 1e-05) # t1544: "cuda:0 f32[1, 512, 1]" - t1545 = prims.rsqrt(t1544) # t1545: "cuda:0 f32[1, 512, 1]" - t1546 = prims.broadcast_in_dim(t1545, (1, 512, 4096), (0, 1, 2)) # t1546: "cuda:0 f32[1, 512, 4096]" - t1547 = ltorch.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - # t1547 = prims.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]" - t1548 = prims.convert_element_type(t1547, dtypes.bfloat16) # t1548: "cuda:0 bf16[1, 512, 4096]" - t1549 = prims.broadcast_in_dim(t_transformer_h_10_norm_2_weight, (1, 512, 4096), (2,)) # t1549: "cuda:0 bf16[1, 512, 4096]" - t1550 = prims.convert_element_type(t1548, dtypes.float32) # t1550: "cuda:0 f32[1, 512, 4096]" - t1551 = prims.convert_element_type(t1549, dtypes.float32) # t1551: "cuda:0 f32[1, 512, 4096]" - t1552 = ltorch.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - # t1552 = prims.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]" - t1553 = prims.convert_element_type(t1552, dtypes.bfloat16) # t1553: "cuda:0 bf16[1, 512, 4096]" - t1554 = torch.nn.functional.linear(t1553, t_transformer_h_10_mlp_fc_1_weight, None) # t1554: "cuda:0 bf16[1, 512, 11008]" - # t1554 = ltorch.linear(t1553, t_transformer_h_10_mlp_fc_1_weight, None) # t1554: "cuda:0 bf16[1, 512, 11008]" - # t1554 = prims.linear(t1553, t_transformer_h_10_mlp_fc_1_weight, None) # t1554: "cuda:0 bf16[1, 512, 11008]" - t1555 = torch.nn.functional.linear(t1553, t_transformer_h_10_mlp_fc_2_weight, None) # t1555: "cuda:0 bf16[1, 512, 11008]" - # t1555 = ltorch.linear(t1553, t_transformer_h_10_mlp_fc_2_weight, None) # t1555: "cuda:0 bf16[1, 512, 11008]" - # t1555 = prims.linear(t1553, t_transformer_h_10_mlp_fc_2_weight, None) # t1555: "cuda:0 bf16[1, 512, 11008]" - t1556 = prims.convert_element_type(t1554, dtypes.float32) # t1556: "cuda:0 f32[1, 512, 11008]" - t1557 = prims.neg(t1556) # t1557: "cuda:0 f32[1, 512, 11008]" - t1558 = prims.exp(t1557) # t1558: "cuda:0 f32[1, 512, 11008]" - t1559 = ltorch.add(1.0, t1558, alpha=None) # t1559: "cuda:0 f32[1, 512, 11008]" - # t1559 = prims.add(1.0, t1558) # t1559: "cuda:0 f32[1, 512, 11008]" - t1560 = prims.reciprocal(t1559) # t1560: "cuda:0 f32[1, 512, 11008]" - t1561 = prims.convert_element_type(t1560, dtypes.bfloat16) # t1561: "cuda:0 bf16[1, 512, 11008]" - t1562 = prims.convert_element_type(t1554, dtypes.float32) # t1562: "cuda:0 f32[1, 512, 11008]" - t1563 = prims.convert_element_type(t1561, dtypes.float32) # t1563: "cuda:0 f32[1, 512, 11008]" - t1564 = ltorch.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - # t1564 = prims.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]" - t1565 = prims.convert_element_type(t1564, dtypes.bfloat16) # t1565: "cuda:0 bf16[1, 512, 11008]" - t1566 = prims.convert_element_type(t1565, dtypes.float32) # t1566: "cuda:0 f32[1, 512, 11008]" - t1567 = prims.convert_element_type(t1555, dtypes.float32) # t1567: "cuda:0 f32[1, 512, 11008]" - t1568 = ltorch.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - # t1568 = prims.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]" - t1569 = prims.convert_element_type(t1568, dtypes.bfloat16) # t1569: "cuda:0 bf16[1, 512, 11008]" - t1570 = torch.nn.functional.linear(t1569, t_transformer_h_10_mlp_proj_weight, None) # t1570: "cuda:0 bf16[1, 512, 4096]" - # t1570 = ltorch.linear(t1569, t_transformer_h_10_mlp_proj_weight, None) # t1570: "cuda:0 bf16[1, 512, 4096]" - # t1570 = prims.linear(t1569, t_transformer_h_10_mlp_proj_weight, None) # t1570: "cuda:0 bf16[1, 512, 4096]" - t1571 = prims.convert_element_type(t1570, dtypes.float32) # t1571: "cuda:0 f32[1, 512, 4096]" - t1572 = prims.convert_element_type(t1535, dtypes.float32) # t1572: "cuda:0 f32[1, 512, 4096]" - t1573 = ltorch.add(t1571, t1572, alpha=None) # t1573: "cuda:0 f32[1, 512, 4096]" - # t1573 = prims.add(t1571, t1572) # t1573: "cuda:0 f32[1, 512, 4096]" - t1574 = prims.convert_element_type(t1573, dtypes.bfloat16) # t1574: "cuda:0 bf16[1, 512, 4096]" - t1575 = prims.convert_element_type(t1574, dtypes.float32) # t1575: "cuda:0 f32[1, 512, 4096]" - t1576 = ltorch.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - # t1576 = prims.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]" - t1578 = prims.sum(t1576, (2,)) # t1578: "cuda:0 f32[1, 512]" - t1579 = prims.broadcast_in_dim(t1578, [1, 512, 1], [0, 1]) # t1579: "cuda:0 f32[1, 512, 1]" - t1581 = ltorch.true_divide(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - # t1581 = prims.div(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]" - t1583 = ltorch.add(t1581, 1e-05, alpha=None) # t1583: "cuda:0 f32[1, 512, 1]" - # t1583 = prims.add(t1581, 1e-05) # t1583: "cuda:0 f32[1, 512, 1]" - t1584 = prims.rsqrt(t1583) # t1584: "cuda:0 f32[1, 512, 1]" - t1585 = prims.broadcast_in_dim(t1584, (1, 512, 4096), (0, 1, 2)) # t1585: "cuda:0 f32[1, 512, 4096]" - t1586 = ltorch.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - # t1586 = prims.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]" - t1587 = prims.convert_element_type(t1586, dtypes.bfloat16) # t1587: "cuda:0 bf16[1, 512, 4096]" - t1588 = prims.broadcast_in_dim(t_transformer_h_11_norm_1_weight, (1, 512, 4096), (2,)) # t1588: "cuda:0 bf16[1, 512, 4096]" - t1589 = prims.convert_element_type(t1587, dtypes.float32) # t1589: "cuda:0 f32[1, 512, 4096]" - t1590 = prims.convert_element_type(t1588, dtypes.float32) # t1590: "cuda:0 f32[1, 512, 4096]" - t1591 = ltorch.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - # t1591 = prims.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]" - t1592 = prims.convert_element_type(t1591, dtypes.bfloat16) # t1592: "cuda:0 bf16[1, 512, 4096]" - t1593 = torch.nn.functional.linear(t1592, t_transformer_h_11_attn_attn_weight, None) # t1593: "cuda:0 bf16[1, 512, 12288]" - # t1593 = ltorch.linear(t1592, t_transformer_h_11_attn_attn_weight, None) # t1593: "cuda:0 bf16[1, 512, 12288]" - # t1593 = prims.linear(t1592, t_transformer_h_11_attn_attn_weight, None) # t1593: "cuda:0 bf16[1, 512, 12288]" - t1599 = prims.reshape(t1593, (1, 512, 32, 3, 128)) # t1599: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1605 = prims.transpose(t1599, (0, 2, 3, 1, 4)) # t1605: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1606, t1607, t1608) = ltorch.split(t1605, (1, 1, 1), 2) - # t1606 = prims.slice_prim(t1605, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1606: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1607 = prims.slice_prim(t1605, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1607: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1608 = prims.slice_prim(t1605, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1608: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1614 = prims.reshape(t1606, (1, 32, 512, 128)) # t1614: "cuda:0 bf16[1, 32, 512, 128]" - t1620 = prims.reshape(t1607, (1, 32, 512, 128)) # t1620: "cuda:0 bf16[1, 32, 512, 128]" - t1626 = prims.reshape(t1608, (1, 32, 512, 128)) # t1626: "cuda:0 bf16[1, 32, 512, 128]" - t1627 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1627: "cuda:0 bf16[1, 32, 512, 128]" - t1628 = prims.slice_prim(t1627, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1628: "cuda:0 bf16[1, 32, 512, 64]" - t1629 = prims.slice_prim(t1627, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1629: "cuda:0 bf16[1, 32, 512, 64]" - t1630 = prims.convert_element_type(t1629, dtypes.float32) # t1630: "cuda:0 f32[1, 32, 512, 64]" - t1631 = prims.neg(t1630) # t1631: "cuda:0 f32[1, 32, 512, 64]" - t1632 = prims.convert_element_type(t1631, dtypes.bfloat16) # t1632: "cuda:0 bf16[1, 32, 512, 64]" - t1634 = prims.cat((t1632, t1628), -1) # t1634: "cuda:0 bf16[1, 32, 512, 128]" - t1635 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1635: "cuda:0 f32[1, 32, 512, 128]" - t1636 = prims.convert_element_type(t1627, dtypes.float32) # t1636: "cuda:0 f32[1, 32, 512, 128]" - t1637 = ltorch.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - # t1637 = prims.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]" - t1638 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1638: "cuda:0 f32[1, 32, 512, 128]" - t1639 = prims.convert_element_type(t1634, dtypes.float32) # t1639: "cuda:0 f32[1, 32, 512, 128]" - t1640 = ltorch.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - # t1640 = prims.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]" - t1641 = ltorch.add(t1637, t1640, alpha=None) # t1641: "cuda:0 f32[1, 32, 512, 128]" - # t1641 = prims.add(t1637, t1640) # t1641: "cuda:0 f32[1, 32, 512, 128]" - t1642 = prims.convert_element_type(t1641, dtypes.bfloat16) # t1642: "cuda:0 bf16[1, 32, 512, 128]" - t1643 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1643: "cuda:0 bf16[1, 32, 512, 128]" - t1644 = prims.slice_prim(t1643, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1644: "cuda:0 bf16[1, 32, 512, 64]" - t1645 = prims.slice_prim(t1643, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1645: "cuda:0 bf16[1, 32, 512, 64]" - t1646 = prims.convert_element_type(t1645, dtypes.float32) # t1646: "cuda:0 f32[1, 32, 512, 64]" - t1647 = prims.neg(t1646) # t1647: "cuda:0 f32[1, 32, 512, 64]" - t1648 = prims.convert_element_type(t1647, dtypes.bfloat16) # t1648: "cuda:0 bf16[1, 32, 512, 64]" - t1650 = prims.cat((t1648, t1644), -1) # t1650: "cuda:0 bf16[1, 32, 512, 128]" - t1651 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1651: "cuda:0 f32[1, 32, 512, 128]" - t1652 = prims.convert_element_type(t1643, dtypes.float32) # t1652: "cuda:0 f32[1, 32, 512, 128]" - t1653 = ltorch.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - # t1653 = prims.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]" - t1654 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1654: "cuda:0 f32[1, 32, 512, 128]" - t1655 = prims.convert_element_type(t1650, dtypes.float32) # t1655: "cuda:0 f32[1, 32, 512, 128]" - t1656 = ltorch.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - # t1656 = prims.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]" - t1657 = ltorch.add(t1653, t1656, alpha=None) # t1657: "cuda:0 f32[1, 32, 512, 128]" - # t1657 = prims.add(t1653, t1656) # t1657: "cuda:0 f32[1, 32, 512, 128]" - t1658 = prims.convert_element_type(t1657, dtypes.bfloat16) # t1658: "cuda:0 bf16[1, 32, 512, 128]" - t1659 = prims.slice_prim(t1614, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1659: "cuda:0 bf16[1, 32, 512, 0]" - t1661 = prims.cat((t1642, t1659), -1) # t1661: "cuda:0 bf16[1, 32, 512, 128]" - t1662 = prims.slice_prim(t1620, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1662: "cuda:0 bf16[1, 32, 512, 0]" - t1664 = prims.cat((t1658, t1662), -1) # t1664: "cuda:0 bf16[1, 32, 512, 128]" - (t1665, t1666, t1667, t1668) = cudnn_sdpa_fwd(t1661, t1664, t1626, None, 0.0, True, scale=0.08838834764831843) - t1671 = prims.transpose(t1665, (0, 2, 1, 3)) # t1671: "cuda:0 bf16[1, 512, 32, 128]" - t1675 = prims.reshape(t1671, (1, 512, 4096)) # t1675: "cuda:0 bf16[1, 512, 4096]" - t1676 = torch.nn.functional.linear(t1675, t_transformer_h_11_attn_proj_weight, None) # t1676: "cuda:0 bf16[1, 512, 4096]" - # t1676 = ltorch.linear(t1675, t_transformer_h_11_attn_proj_weight, None) # t1676: "cuda:0 bf16[1, 512, 4096]" - # t1676 = prims.linear(t1675, t_transformer_h_11_attn_proj_weight, None) # t1676: "cuda:0 bf16[1, 512, 4096]" - t1677 = prims.convert_element_type(t1676, dtypes.float32) # t1677: "cuda:0 f32[1, 512, 4096]" - t1678 = prims.convert_element_type(t1574, dtypes.float32) # t1678: "cuda:0 f32[1, 512, 4096]" - t1679 = ltorch.add(t1677, t1678, alpha=None) # t1679: "cuda:0 f32[1, 512, 4096]" - # t1679 = prims.add(t1677, t1678) # t1679: "cuda:0 f32[1, 512, 4096]" - t1680 = prims.convert_element_type(t1679, dtypes.bfloat16) # t1680: "cuda:0 bf16[1, 512, 4096]" - t1681 = prims.convert_element_type(t1680, dtypes.float32) # t1681: "cuda:0 f32[1, 512, 4096]" - t1682 = ltorch.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - # t1682 = prims.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]" - t1684 = prims.sum(t1682, (2,)) # t1684: "cuda:0 f32[1, 512]" - t1685 = prims.broadcast_in_dim(t1684, [1, 512, 1], [0, 1]) # t1685: "cuda:0 f32[1, 512, 1]" - t1687 = ltorch.true_divide(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - # t1687 = prims.div(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]" - t1689 = ltorch.add(t1687, 1e-05, alpha=None) # t1689: "cuda:0 f32[1, 512, 1]" - # t1689 = prims.add(t1687, 1e-05) # t1689: "cuda:0 f32[1, 512, 1]" - t1690 = prims.rsqrt(t1689) # t1690: "cuda:0 f32[1, 512, 1]" - t1691 = prims.broadcast_in_dim(t1690, (1, 512, 4096), (0, 1, 2)) # t1691: "cuda:0 f32[1, 512, 4096]" - t1692 = ltorch.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - # t1692 = prims.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]" - t1693 = prims.convert_element_type(t1692, dtypes.bfloat16) # t1693: "cuda:0 bf16[1, 512, 4096]" - t1694 = prims.broadcast_in_dim(t_transformer_h_11_norm_2_weight, (1, 512, 4096), (2,)) # t1694: "cuda:0 bf16[1, 512, 4096]" - t1695 = prims.convert_element_type(t1693, dtypes.float32) # t1695: "cuda:0 f32[1, 512, 4096]" - t1696 = prims.convert_element_type(t1694, dtypes.float32) # t1696: "cuda:0 f32[1, 512, 4096]" - t1697 = ltorch.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - # t1697 = prims.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]" - t1698 = prims.convert_element_type(t1697, dtypes.bfloat16) # t1698: "cuda:0 bf16[1, 512, 4096]" - t1699 = torch.nn.functional.linear(t1698, t_transformer_h_11_mlp_fc_1_weight, None) # t1699: "cuda:0 bf16[1, 512, 11008]" - # t1699 = ltorch.linear(t1698, t_transformer_h_11_mlp_fc_1_weight, None) # t1699: "cuda:0 bf16[1, 512, 11008]" - # t1699 = prims.linear(t1698, t_transformer_h_11_mlp_fc_1_weight, None) # t1699: "cuda:0 bf16[1, 512, 11008]" - t1700 = torch.nn.functional.linear(t1698, t_transformer_h_11_mlp_fc_2_weight, None) # t1700: "cuda:0 bf16[1, 512, 11008]" - # t1700 = ltorch.linear(t1698, t_transformer_h_11_mlp_fc_2_weight, None) # t1700: "cuda:0 bf16[1, 512, 11008]" - # t1700 = prims.linear(t1698, t_transformer_h_11_mlp_fc_2_weight, None) # t1700: "cuda:0 bf16[1, 512, 11008]" - t1701 = prims.convert_element_type(t1699, dtypes.float32) # t1701: "cuda:0 f32[1, 512, 11008]" - t1702 = prims.neg(t1701) # t1702: "cuda:0 f32[1, 512, 11008]" - t1703 = prims.exp(t1702) # t1703: "cuda:0 f32[1, 512, 11008]" - t1704 = ltorch.add(1.0, t1703, alpha=None) # t1704: "cuda:0 f32[1, 512, 11008]" - # t1704 = prims.add(1.0, t1703) # t1704: "cuda:0 f32[1, 512, 11008]" - t1705 = prims.reciprocal(t1704) # t1705: "cuda:0 f32[1, 512, 11008]" - t1706 = prims.convert_element_type(t1705, dtypes.bfloat16) # t1706: "cuda:0 bf16[1, 512, 11008]" - t1707 = prims.convert_element_type(t1699, dtypes.float32) # t1707: "cuda:0 f32[1, 512, 11008]" - t1708 = prims.convert_element_type(t1706, dtypes.float32) # t1708: "cuda:0 f32[1, 512, 11008]" - t1709 = ltorch.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - # t1709 = prims.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]" - t1710 = prims.convert_element_type(t1709, dtypes.bfloat16) # t1710: "cuda:0 bf16[1, 512, 11008]" - t1711 = prims.convert_element_type(t1710, dtypes.float32) # t1711: "cuda:0 f32[1, 512, 11008]" - t1712 = prims.convert_element_type(t1700, dtypes.float32) # t1712: "cuda:0 f32[1, 512, 11008]" - t1713 = ltorch.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - # t1713 = prims.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]" - t1714 = prims.convert_element_type(t1713, dtypes.bfloat16) # t1714: "cuda:0 bf16[1, 512, 11008]" - t1715 = torch.nn.functional.linear(t1714, t_transformer_h_11_mlp_proj_weight, None) # t1715: "cuda:0 bf16[1, 512, 4096]" - # t1715 = ltorch.linear(t1714, t_transformer_h_11_mlp_proj_weight, None) # t1715: "cuda:0 bf16[1, 512, 4096]" - # t1715 = prims.linear(t1714, t_transformer_h_11_mlp_proj_weight, None) # t1715: "cuda:0 bf16[1, 512, 4096]" - t1716 = prims.convert_element_type(t1715, dtypes.float32) # t1716: "cuda:0 f32[1, 512, 4096]" - t1717 = prims.convert_element_type(t1680, dtypes.float32) # t1717: "cuda:0 f32[1, 512, 4096]" - t1718 = ltorch.add(t1716, t1717, alpha=None) # t1718: "cuda:0 f32[1, 512, 4096]" - # t1718 = prims.add(t1716, t1717) # t1718: "cuda:0 f32[1, 512, 4096]" - t1719 = prims.convert_element_type(t1718, dtypes.bfloat16) # t1719: "cuda:0 bf16[1, 512, 4096]" - t1720 = prims.convert_element_type(t1719, dtypes.float32) # t1720: "cuda:0 f32[1, 512, 4096]" - t1721 = ltorch.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - # t1721 = prims.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]" - t1723 = prims.sum(t1721, (2,)) # t1723: "cuda:0 f32[1, 512]" - t1724 = prims.broadcast_in_dim(t1723, [1, 512, 1], [0, 1]) # t1724: "cuda:0 f32[1, 512, 1]" - t1726 = ltorch.true_divide(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - # t1726 = prims.div(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]" - t1728 = ltorch.add(t1726, 1e-05, alpha=None) # t1728: "cuda:0 f32[1, 512, 1]" - # t1728 = prims.add(t1726, 1e-05) # t1728: "cuda:0 f32[1, 512, 1]" - t1729 = prims.rsqrt(t1728) # t1729: "cuda:0 f32[1, 512, 1]" - t1730 = prims.broadcast_in_dim(t1729, (1, 512, 4096), (0, 1, 2)) # t1730: "cuda:0 f32[1, 512, 4096]" - t1731 = ltorch.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - # t1731 = prims.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]" - t1732 = prims.convert_element_type(t1731, dtypes.bfloat16) # t1732: "cuda:0 bf16[1, 512, 4096]" - t1733 = prims.broadcast_in_dim(t_transformer_h_12_norm_1_weight, (1, 512, 4096), (2,)) # t1733: "cuda:0 bf16[1, 512, 4096]" - t1734 = prims.convert_element_type(t1732, dtypes.float32) # t1734: "cuda:0 f32[1, 512, 4096]" - t1735 = prims.convert_element_type(t1733, dtypes.float32) # t1735: "cuda:0 f32[1, 512, 4096]" - t1736 = ltorch.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - # t1736 = prims.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]" - t1737 = prims.convert_element_type(t1736, dtypes.bfloat16) # t1737: "cuda:0 bf16[1, 512, 4096]" - t1738 = torch.nn.functional.linear(t1737, t_transformer_h_12_attn_attn_weight, None) # t1738: "cuda:0 bf16[1, 512, 12288]" - # t1738 = ltorch.linear(t1737, t_transformer_h_12_attn_attn_weight, None) # t1738: "cuda:0 bf16[1, 512, 12288]" - # t1738 = prims.linear(t1737, t_transformer_h_12_attn_attn_weight, None) # t1738: "cuda:0 bf16[1, 512, 12288]" - t1744 = prims.reshape(t1738, (1, 512, 32, 3, 128)) # t1744: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1750 = prims.transpose(t1744, (0, 2, 3, 1, 4)) # t1750: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1751, t1752, t1753) = ltorch.split(t1750, (1, 1, 1), 2) - # t1751 = prims.slice_prim(t1750, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1751: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1752 = prims.slice_prim(t1750, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1752: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1753 = prims.slice_prim(t1750, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1753: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1759 = prims.reshape(t1751, (1, 32, 512, 128)) # t1759: "cuda:0 bf16[1, 32, 512, 128]" - t1765 = prims.reshape(t1752, (1, 32, 512, 128)) # t1765: "cuda:0 bf16[1, 32, 512, 128]" - t1771 = prims.reshape(t1753, (1, 32, 512, 128)) # t1771: "cuda:0 bf16[1, 32, 512, 128]" - t1772 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1772: "cuda:0 bf16[1, 32, 512, 128]" - t1773 = prims.slice_prim(t1772, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1773: "cuda:0 bf16[1, 32, 512, 64]" - t1774 = prims.slice_prim(t1772, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1774: "cuda:0 bf16[1, 32, 512, 64]" - t1775 = prims.convert_element_type(t1774, dtypes.float32) # t1775: "cuda:0 f32[1, 32, 512, 64]" - t1776 = prims.neg(t1775) # t1776: "cuda:0 f32[1, 32, 512, 64]" - t1777 = prims.convert_element_type(t1776, dtypes.bfloat16) # t1777: "cuda:0 bf16[1, 32, 512, 64]" - t1779 = prims.cat((t1777, t1773), -1) # t1779: "cuda:0 bf16[1, 32, 512, 128]" - t1780 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1780: "cuda:0 f32[1, 32, 512, 128]" - t1781 = prims.convert_element_type(t1772, dtypes.float32) # t1781: "cuda:0 f32[1, 32, 512, 128]" - t1782 = ltorch.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - # t1782 = prims.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]" - t1783 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1783: "cuda:0 f32[1, 32, 512, 128]" - t1784 = prims.convert_element_type(t1779, dtypes.float32) # t1784: "cuda:0 f32[1, 32, 512, 128]" - t1785 = ltorch.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - # t1785 = prims.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]" - t1786 = ltorch.add(t1782, t1785, alpha=None) # t1786: "cuda:0 f32[1, 32, 512, 128]" - # t1786 = prims.add(t1782, t1785) # t1786: "cuda:0 f32[1, 32, 512, 128]" - t1787 = prims.convert_element_type(t1786, dtypes.bfloat16) # t1787: "cuda:0 bf16[1, 32, 512, 128]" - t1788 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1788: "cuda:0 bf16[1, 32, 512, 128]" - t1789 = prims.slice_prim(t1788, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1789: "cuda:0 bf16[1, 32, 512, 64]" - t1790 = prims.slice_prim(t1788, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1790: "cuda:0 bf16[1, 32, 512, 64]" - t1791 = prims.convert_element_type(t1790, dtypes.float32) # t1791: "cuda:0 f32[1, 32, 512, 64]" - t1792 = prims.neg(t1791) # t1792: "cuda:0 f32[1, 32, 512, 64]" - t1793 = prims.convert_element_type(t1792, dtypes.bfloat16) # t1793: "cuda:0 bf16[1, 32, 512, 64]" - t1795 = prims.cat((t1793, t1789), -1) # t1795: "cuda:0 bf16[1, 32, 512, 128]" - t1796 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1796: "cuda:0 f32[1, 32, 512, 128]" - t1797 = prims.convert_element_type(t1788, dtypes.float32) # t1797: "cuda:0 f32[1, 32, 512, 128]" - t1798 = ltorch.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - # t1798 = prims.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]" - t1799 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1799: "cuda:0 f32[1, 32, 512, 128]" - t1800 = prims.convert_element_type(t1795, dtypes.float32) # t1800: "cuda:0 f32[1, 32, 512, 128]" - t1801 = ltorch.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - # t1801 = prims.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]" - t1802 = ltorch.add(t1798, t1801, alpha=None) # t1802: "cuda:0 f32[1, 32, 512, 128]" - # t1802 = prims.add(t1798, t1801) # t1802: "cuda:0 f32[1, 32, 512, 128]" - t1803 = prims.convert_element_type(t1802, dtypes.bfloat16) # t1803: "cuda:0 bf16[1, 32, 512, 128]" - t1804 = prims.slice_prim(t1759, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1804: "cuda:0 bf16[1, 32, 512, 0]" - t1806 = prims.cat((t1787, t1804), -1) # t1806: "cuda:0 bf16[1, 32, 512, 128]" - t1807 = prims.slice_prim(t1765, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1807: "cuda:0 bf16[1, 32, 512, 0]" - t1809 = prims.cat((t1803, t1807), -1) # t1809: "cuda:0 bf16[1, 32, 512, 128]" - (t1810, t1811, t1812, t1813) = cudnn_sdpa_fwd(t1806, t1809, t1771, None, 0.0, True, scale=0.08838834764831843) - t1816 = prims.transpose(t1810, (0, 2, 1, 3)) # t1816: "cuda:0 bf16[1, 512, 32, 128]" - t1820 = prims.reshape(t1816, (1, 512, 4096)) # t1820: "cuda:0 bf16[1, 512, 4096]" - t1821 = torch.nn.functional.linear(t1820, t_transformer_h_12_attn_proj_weight, None) # t1821: "cuda:0 bf16[1, 512, 4096]" - # t1821 = ltorch.linear(t1820, t_transformer_h_12_attn_proj_weight, None) # t1821: "cuda:0 bf16[1, 512, 4096]" - # t1821 = prims.linear(t1820, t_transformer_h_12_attn_proj_weight, None) # t1821: "cuda:0 bf16[1, 512, 4096]" - t1822 = prims.convert_element_type(t1821, dtypes.float32) # t1822: "cuda:0 f32[1, 512, 4096]" - t1823 = prims.convert_element_type(t1719, dtypes.float32) # t1823: "cuda:0 f32[1, 512, 4096]" - t1824 = ltorch.add(t1822, t1823, alpha=None) # t1824: "cuda:0 f32[1, 512, 4096]" - # t1824 = prims.add(t1822, t1823) # t1824: "cuda:0 f32[1, 512, 4096]" - t1825 = prims.convert_element_type(t1824, dtypes.bfloat16) # t1825: "cuda:0 bf16[1, 512, 4096]" - t1826 = prims.convert_element_type(t1825, dtypes.float32) # t1826: "cuda:0 f32[1, 512, 4096]" - t1827 = ltorch.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - # t1827 = prims.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]" - t1829 = prims.sum(t1827, (2,)) # t1829: "cuda:0 f32[1, 512]" - t1830 = prims.broadcast_in_dim(t1829, [1, 512, 1], [0, 1]) # t1830: "cuda:0 f32[1, 512, 1]" - t1832 = ltorch.true_divide(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - # t1832 = prims.div(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]" - t1834 = ltorch.add(t1832, 1e-05, alpha=None) # t1834: "cuda:0 f32[1, 512, 1]" - # t1834 = prims.add(t1832, 1e-05) # t1834: "cuda:0 f32[1, 512, 1]" - t1835 = prims.rsqrt(t1834) # t1835: "cuda:0 f32[1, 512, 1]" - t1836 = prims.broadcast_in_dim(t1835, (1, 512, 4096), (0, 1, 2)) # t1836: "cuda:0 f32[1, 512, 4096]" - t1837 = ltorch.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - # t1837 = prims.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]" - t1838 = prims.convert_element_type(t1837, dtypes.bfloat16) # t1838: "cuda:0 bf16[1, 512, 4096]" - t1839 = prims.broadcast_in_dim(t_transformer_h_12_norm_2_weight, (1, 512, 4096), (2,)) # t1839: "cuda:0 bf16[1, 512, 4096]" - t1840 = prims.convert_element_type(t1838, dtypes.float32) # t1840: "cuda:0 f32[1, 512, 4096]" - t1841 = prims.convert_element_type(t1839, dtypes.float32) # t1841: "cuda:0 f32[1, 512, 4096]" - t1842 = ltorch.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - # t1842 = prims.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]" - t1843 = prims.convert_element_type(t1842, dtypes.bfloat16) # t1843: "cuda:0 bf16[1, 512, 4096]" - t1844 = torch.nn.functional.linear(t1843, t_transformer_h_12_mlp_fc_1_weight, None) # t1844: "cuda:0 bf16[1, 512, 11008]" - # t1844 = ltorch.linear(t1843, t_transformer_h_12_mlp_fc_1_weight, None) # t1844: "cuda:0 bf16[1, 512, 11008]" - # t1844 = prims.linear(t1843, t_transformer_h_12_mlp_fc_1_weight, None) # t1844: "cuda:0 bf16[1, 512, 11008]" - t1845 = torch.nn.functional.linear(t1843, t_transformer_h_12_mlp_fc_2_weight, None) # t1845: "cuda:0 bf16[1, 512, 11008]" - # t1845 = ltorch.linear(t1843, t_transformer_h_12_mlp_fc_2_weight, None) # t1845: "cuda:0 bf16[1, 512, 11008]" - # t1845 = prims.linear(t1843, t_transformer_h_12_mlp_fc_2_weight, None) # t1845: "cuda:0 bf16[1, 512, 11008]" - t1846 = prims.convert_element_type(t1844, dtypes.float32) # t1846: "cuda:0 f32[1, 512, 11008]" - t1847 = prims.neg(t1846) # t1847: "cuda:0 f32[1, 512, 11008]" - t1848 = prims.exp(t1847) # t1848: "cuda:0 f32[1, 512, 11008]" - t1849 = ltorch.add(1.0, t1848, alpha=None) # t1849: "cuda:0 f32[1, 512, 11008]" - # t1849 = prims.add(1.0, t1848) # t1849: "cuda:0 f32[1, 512, 11008]" - t1850 = prims.reciprocal(t1849) # t1850: "cuda:0 f32[1, 512, 11008]" - t1851 = prims.convert_element_type(t1850, dtypes.bfloat16) # t1851: "cuda:0 bf16[1, 512, 11008]" - t1852 = prims.convert_element_type(t1844, dtypes.float32) # t1852: "cuda:0 f32[1, 512, 11008]" - t1853 = prims.convert_element_type(t1851, dtypes.float32) # t1853: "cuda:0 f32[1, 512, 11008]" - t1854 = ltorch.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - # t1854 = prims.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]" - t1855 = prims.convert_element_type(t1854, dtypes.bfloat16) # t1855: "cuda:0 bf16[1, 512, 11008]" - t1856 = prims.convert_element_type(t1855, dtypes.float32) # t1856: "cuda:0 f32[1, 512, 11008]" - t1857 = prims.convert_element_type(t1845, dtypes.float32) # t1857: "cuda:0 f32[1, 512, 11008]" - t1858 = ltorch.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - # t1858 = prims.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]" - t1859 = prims.convert_element_type(t1858, dtypes.bfloat16) # t1859: "cuda:0 bf16[1, 512, 11008]" - t1860 = torch.nn.functional.linear(t1859, t_transformer_h_12_mlp_proj_weight, None) # t1860: "cuda:0 bf16[1, 512, 4096]" - # t1860 = ltorch.linear(t1859, t_transformer_h_12_mlp_proj_weight, None) # t1860: "cuda:0 bf16[1, 512, 4096]" - # t1860 = prims.linear(t1859, t_transformer_h_12_mlp_proj_weight, None) # t1860: "cuda:0 bf16[1, 512, 4096]" - t1861 = prims.convert_element_type(t1860, dtypes.float32) # t1861: "cuda:0 f32[1, 512, 4096]" - t1862 = prims.convert_element_type(t1825, dtypes.float32) # t1862: "cuda:0 f32[1, 512, 4096]" - t1863 = ltorch.add(t1861, t1862, alpha=None) # t1863: "cuda:0 f32[1, 512, 4096]" - # t1863 = prims.add(t1861, t1862) # t1863: "cuda:0 f32[1, 512, 4096]" - t1864 = prims.convert_element_type(t1863, dtypes.bfloat16) # t1864: "cuda:0 bf16[1, 512, 4096]" - t1865 = prims.convert_element_type(t1864, dtypes.float32) # t1865: "cuda:0 f32[1, 512, 4096]" - t1866 = ltorch.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - # t1866 = prims.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]" - t1868 = prims.sum(t1866, (2,)) # t1868: "cuda:0 f32[1, 512]" - t1869 = prims.broadcast_in_dim(t1868, [1, 512, 1], [0, 1]) # t1869: "cuda:0 f32[1, 512, 1]" - t1871 = ltorch.true_divide(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - # t1871 = prims.div(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]" - t1873 = ltorch.add(t1871, 1e-05, alpha=None) # t1873: "cuda:0 f32[1, 512, 1]" - # t1873 = prims.add(t1871, 1e-05) # t1873: "cuda:0 f32[1, 512, 1]" - t1874 = prims.rsqrt(t1873) # t1874: "cuda:0 f32[1, 512, 1]" - t1875 = prims.broadcast_in_dim(t1874, (1, 512, 4096), (0, 1, 2)) # t1875: "cuda:0 f32[1, 512, 4096]" - t1876 = ltorch.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - # t1876 = prims.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]" - t1877 = prims.convert_element_type(t1876, dtypes.bfloat16) # t1877: "cuda:0 bf16[1, 512, 4096]" - t1878 = prims.broadcast_in_dim(t_transformer_h_13_norm_1_weight, (1, 512, 4096), (2,)) # t1878: "cuda:0 bf16[1, 512, 4096]" - t1879 = prims.convert_element_type(t1877, dtypes.float32) # t1879: "cuda:0 f32[1, 512, 4096]" - t1880 = prims.convert_element_type(t1878, dtypes.float32) # t1880: "cuda:0 f32[1, 512, 4096]" - t1881 = ltorch.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - # t1881 = prims.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]" - t1882 = prims.convert_element_type(t1881, dtypes.bfloat16) # t1882: "cuda:0 bf16[1, 512, 4096]" - t1883 = torch.nn.functional.linear(t1882, t_transformer_h_13_attn_attn_weight, None) # t1883: "cuda:0 bf16[1, 512, 12288]" - # t1883 = ltorch.linear(t1882, t_transformer_h_13_attn_attn_weight, None) # t1883: "cuda:0 bf16[1, 512, 12288]" - # t1883 = prims.linear(t1882, t_transformer_h_13_attn_attn_weight, None) # t1883: "cuda:0 bf16[1, 512, 12288]" - t1889 = prims.reshape(t1883, (1, 512, 32, 3, 128)) # t1889: "cuda:0 bf16[1, 512, 32, 3, 128]" - t1895 = prims.transpose(t1889, (0, 2, 3, 1, 4)) # t1895: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t1896, t1897, t1898) = ltorch.split(t1895, (1, 1, 1), 2) - # t1896 = prims.slice_prim(t1895, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1896: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1897 = prims.slice_prim(t1895, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1897: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t1898 = prims.slice_prim(t1895, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1898: "cuda:0 bf16[1, 32, 1, 512, 128]" - t1904 = prims.reshape(t1896, (1, 32, 512, 128)) # t1904: "cuda:0 bf16[1, 32, 512, 128]" - t1910 = prims.reshape(t1897, (1, 32, 512, 128)) # t1910: "cuda:0 bf16[1, 32, 512, 128]" - t1916 = prims.reshape(t1898, (1, 32, 512, 128)) # t1916: "cuda:0 bf16[1, 32, 512, 128]" - t1917 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1917: "cuda:0 bf16[1, 32, 512, 128]" - t1918 = prims.slice_prim(t1917, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1918: "cuda:0 bf16[1, 32, 512, 64]" - t1919 = prims.slice_prim(t1917, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1919: "cuda:0 bf16[1, 32, 512, 64]" - t1920 = prims.convert_element_type(t1919, dtypes.float32) # t1920: "cuda:0 f32[1, 32, 512, 64]" - t1921 = prims.neg(t1920) # t1921: "cuda:0 f32[1, 32, 512, 64]" - t1922 = prims.convert_element_type(t1921, dtypes.bfloat16) # t1922: "cuda:0 bf16[1, 32, 512, 64]" - t1924 = prims.cat((t1922, t1918), -1) # t1924: "cuda:0 bf16[1, 32, 512, 128]" - t1925 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1925: "cuda:0 f32[1, 32, 512, 128]" - t1926 = prims.convert_element_type(t1917, dtypes.float32) # t1926: "cuda:0 f32[1, 32, 512, 128]" - t1927 = ltorch.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - # t1927 = prims.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]" - t1928 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1928: "cuda:0 f32[1, 32, 512, 128]" - t1929 = prims.convert_element_type(t1924, dtypes.float32) # t1929: "cuda:0 f32[1, 32, 512, 128]" - t1930 = ltorch.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - # t1930 = prims.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]" - t1931 = ltorch.add(t1927, t1930, alpha=None) # t1931: "cuda:0 f32[1, 32, 512, 128]" - # t1931 = prims.add(t1927, t1930) # t1931: "cuda:0 f32[1, 32, 512, 128]" - t1932 = prims.convert_element_type(t1931, dtypes.bfloat16) # t1932: "cuda:0 bf16[1, 32, 512, 128]" - t1933 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t1933: "cuda:0 bf16[1, 32, 512, 128]" - t1934 = prims.slice_prim(t1933, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t1934: "cuda:0 bf16[1, 32, 512, 64]" - t1935 = prims.slice_prim(t1933, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t1935: "cuda:0 bf16[1, 32, 512, 64]" - t1936 = prims.convert_element_type(t1935, dtypes.float32) # t1936: "cuda:0 f32[1, 32, 512, 64]" - t1937 = prims.neg(t1936) # t1937: "cuda:0 f32[1, 32, 512, 64]" - t1938 = prims.convert_element_type(t1937, dtypes.bfloat16) # t1938: "cuda:0 bf16[1, 32, 512, 64]" - t1940 = prims.cat((t1938, t1934), -1) # t1940: "cuda:0 bf16[1, 32, 512, 128]" - t1941 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t1941: "cuda:0 f32[1, 32, 512, 128]" - t1942 = prims.convert_element_type(t1933, dtypes.float32) # t1942: "cuda:0 f32[1, 32, 512, 128]" - t1943 = ltorch.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - # t1943 = prims.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]" - t1944 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t1944: "cuda:0 f32[1, 32, 512, 128]" - t1945 = prims.convert_element_type(t1940, dtypes.float32) # t1945: "cuda:0 f32[1, 32, 512, 128]" - t1946 = ltorch.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - # t1946 = prims.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]" - t1947 = ltorch.add(t1943, t1946, alpha=None) # t1947: "cuda:0 f32[1, 32, 512, 128]" - # t1947 = prims.add(t1943, t1946) # t1947: "cuda:0 f32[1, 32, 512, 128]" - t1948 = prims.convert_element_type(t1947, dtypes.bfloat16) # t1948: "cuda:0 bf16[1, 32, 512, 128]" - t1949 = prims.slice_prim(t1904, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1949: "cuda:0 bf16[1, 32, 512, 0]" - t1951 = prims.cat((t1932, t1949), -1) # t1951: "cuda:0 bf16[1, 32, 512, 128]" - t1952 = prims.slice_prim(t1910, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t1952: "cuda:0 bf16[1, 32, 512, 0]" - t1954 = prims.cat((t1948, t1952), -1) # t1954: "cuda:0 bf16[1, 32, 512, 128]" - (t1955, t1956, t1957, t1958) = cudnn_sdpa_fwd(t1951, t1954, t1916, None, 0.0, True, scale=0.08838834764831843) - t1961 = prims.transpose(t1955, (0, 2, 1, 3)) # t1961: "cuda:0 bf16[1, 512, 32, 128]" - t1965 = prims.reshape(t1961, (1, 512, 4096)) # t1965: "cuda:0 bf16[1, 512, 4096]" - t1966 = torch.nn.functional.linear(t1965, t_transformer_h_13_attn_proj_weight, None) # t1966: "cuda:0 bf16[1, 512, 4096]" - # t1966 = ltorch.linear(t1965, t_transformer_h_13_attn_proj_weight, None) # t1966: "cuda:0 bf16[1, 512, 4096]" - # t1966 = prims.linear(t1965, t_transformer_h_13_attn_proj_weight, None) # t1966: "cuda:0 bf16[1, 512, 4096]" - t1967 = prims.convert_element_type(t1966, dtypes.float32) # t1967: "cuda:0 f32[1, 512, 4096]" - t1968 = prims.convert_element_type(t1864, dtypes.float32) # t1968: "cuda:0 f32[1, 512, 4096]" - t1969 = ltorch.add(t1967, t1968, alpha=None) # t1969: "cuda:0 f32[1, 512, 4096]" - # t1969 = prims.add(t1967, t1968) # t1969: "cuda:0 f32[1, 512, 4096]" - t1970 = prims.convert_element_type(t1969, dtypes.bfloat16) # t1970: "cuda:0 bf16[1, 512, 4096]" - t1971 = prims.convert_element_type(t1970, dtypes.float32) # t1971: "cuda:0 f32[1, 512, 4096]" - t1972 = ltorch.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - # t1972 = prims.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]" - t1974 = prims.sum(t1972, (2,)) # t1974: "cuda:0 f32[1, 512]" - t1975 = prims.broadcast_in_dim(t1974, [1, 512, 1], [0, 1]) # t1975: "cuda:0 f32[1, 512, 1]" - t1977 = ltorch.true_divide(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - # t1977 = prims.div(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]" - t1979 = ltorch.add(t1977, 1e-05, alpha=None) # t1979: "cuda:0 f32[1, 512, 1]" - # t1979 = prims.add(t1977, 1e-05) # t1979: "cuda:0 f32[1, 512, 1]" - t1980 = prims.rsqrt(t1979) # t1980: "cuda:0 f32[1, 512, 1]" - t1981 = prims.broadcast_in_dim(t1980, (1, 512, 4096), (0, 1, 2)) # t1981: "cuda:0 f32[1, 512, 4096]" - t1982 = ltorch.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - # t1982 = prims.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]" - t1983 = prims.convert_element_type(t1982, dtypes.bfloat16) # t1983: "cuda:0 bf16[1, 512, 4096]" - t1984 = prims.broadcast_in_dim(t_transformer_h_13_norm_2_weight, (1, 512, 4096), (2,)) # t1984: "cuda:0 bf16[1, 512, 4096]" - t1985 = prims.convert_element_type(t1983, dtypes.float32) # t1985: "cuda:0 f32[1, 512, 4096]" - t1986 = prims.convert_element_type(t1984, dtypes.float32) # t1986: "cuda:0 f32[1, 512, 4096]" - t1987 = ltorch.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - # t1987 = prims.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]" - t1988 = prims.convert_element_type(t1987, dtypes.bfloat16) # t1988: "cuda:0 bf16[1, 512, 4096]" - t1989 = torch.nn.functional.linear(t1988, t_transformer_h_13_mlp_fc_1_weight, None) # t1989: "cuda:0 bf16[1, 512, 11008]" - # t1989 = ltorch.linear(t1988, t_transformer_h_13_mlp_fc_1_weight, None) # t1989: "cuda:0 bf16[1, 512, 11008]" - # t1989 = prims.linear(t1988, t_transformer_h_13_mlp_fc_1_weight, None) # t1989: "cuda:0 bf16[1, 512, 11008]" - t1990 = torch.nn.functional.linear(t1988, t_transformer_h_13_mlp_fc_2_weight, None) # t1990: "cuda:0 bf16[1, 512, 11008]" - # t1990 = ltorch.linear(t1988, t_transformer_h_13_mlp_fc_2_weight, None) # t1990: "cuda:0 bf16[1, 512, 11008]" - # t1990 = prims.linear(t1988, t_transformer_h_13_mlp_fc_2_weight, None) # t1990: "cuda:0 bf16[1, 512, 11008]" - t1991 = prims.convert_element_type(t1989, dtypes.float32) # t1991: "cuda:0 f32[1, 512, 11008]" - t1992 = prims.neg(t1991) # t1992: "cuda:0 f32[1, 512, 11008]" - t1993 = prims.exp(t1992) # t1993: "cuda:0 f32[1, 512, 11008]" - t1994 = ltorch.add(1.0, t1993, alpha=None) # t1994: "cuda:0 f32[1, 512, 11008]" - # t1994 = prims.add(1.0, t1993) # t1994: "cuda:0 f32[1, 512, 11008]" - t1995 = prims.reciprocal(t1994) # t1995: "cuda:0 f32[1, 512, 11008]" - t1996 = prims.convert_element_type(t1995, dtypes.bfloat16) # t1996: "cuda:0 bf16[1, 512, 11008]" - t1997 = prims.convert_element_type(t1989, dtypes.float32) # t1997: "cuda:0 f32[1, 512, 11008]" - t1998 = prims.convert_element_type(t1996, dtypes.float32) # t1998: "cuda:0 f32[1, 512, 11008]" - t1999 = ltorch.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - # t1999 = prims.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]" - t2000 = prims.convert_element_type(t1999, dtypes.bfloat16) # t2000: "cuda:0 bf16[1, 512, 11008]" - t2001 = prims.convert_element_type(t2000, dtypes.float32) # t2001: "cuda:0 f32[1, 512, 11008]" - t2002 = prims.convert_element_type(t1990, dtypes.float32) # t2002: "cuda:0 f32[1, 512, 11008]" - t2003 = ltorch.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - # t2003 = prims.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]" - t2004 = prims.convert_element_type(t2003, dtypes.bfloat16) # t2004: "cuda:0 bf16[1, 512, 11008]" - t2005 = torch.nn.functional.linear(t2004, t_transformer_h_13_mlp_proj_weight, None) # t2005: "cuda:0 bf16[1, 512, 4096]" - # t2005 = ltorch.linear(t2004, t_transformer_h_13_mlp_proj_weight, None) # t2005: "cuda:0 bf16[1, 512, 4096]" - # t2005 = prims.linear(t2004, t_transformer_h_13_mlp_proj_weight, None) # t2005: "cuda:0 bf16[1, 512, 4096]" - t2006 = prims.convert_element_type(t2005, dtypes.float32) # t2006: "cuda:0 f32[1, 512, 4096]" - t2007 = prims.convert_element_type(t1970, dtypes.float32) # t2007: "cuda:0 f32[1, 512, 4096]" - t2008 = ltorch.add(t2006, t2007, alpha=None) # t2008: "cuda:0 f32[1, 512, 4096]" - # t2008 = prims.add(t2006, t2007) # t2008: "cuda:0 f32[1, 512, 4096]" - t2009 = prims.convert_element_type(t2008, dtypes.bfloat16) # t2009: "cuda:0 bf16[1, 512, 4096]" - t2010 = prims.convert_element_type(t2009, dtypes.float32) # t2010: "cuda:0 f32[1, 512, 4096]" - t2011 = ltorch.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - # t2011 = prims.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]" - t2013 = prims.sum(t2011, (2,)) # t2013: "cuda:0 f32[1, 512]" - t2014 = prims.broadcast_in_dim(t2013, [1, 512, 1], [0, 1]) # t2014: "cuda:0 f32[1, 512, 1]" - t2016 = ltorch.true_divide(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - # t2016 = prims.div(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]" - t2018 = ltorch.add(t2016, 1e-05, alpha=None) # t2018: "cuda:0 f32[1, 512, 1]" - # t2018 = prims.add(t2016, 1e-05) # t2018: "cuda:0 f32[1, 512, 1]" - t2019 = prims.rsqrt(t2018) # t2019: "cuda:0 f32[1, 512, 1]" - t2020 = prims.broadcast_in_dim(t2019, (1, 512, 4096), (0, 1, 2)) # t2020: "cuda:0 f32[1, 512, 4096]" - t2021 = ltorch.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - # t2021 = prims.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]" - t2022 = prims.convert_element_type(t2021, dtypes.bfloat16) # t2022: "cuda:0 bf16[1, 512, 4096]" - t2023 = prims.broadcast_in_dim(t_transformer_h_14_norm_1_weight, (1, 512, 4096), (2,)) # t2023: "cuda:0 bf16[1, 512, 4096]" - t2024 = prims.convert_element_type(t2022, dtypes.float32) # t2024: "cuda:0 f32[1, 512, 4096]" - t2025 = prims.convert_element_type(t2023, dtypes.float32) # t2025: "cuda:0 f32[1, 512, 4096]" - t2026 = ltorch.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - # t2026 = prims.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]" - t2027 = prims.convert_element_type(t2026, dtypes.bfloat16) # t2027: "cuda:0 bf16[1, 512, 4096]" - t2028 = torch.nn.functional.linear(t2027, t_transformer_h_14_attn_attn_weight, None) # t2028: "cuda:0 bf16[1, 512, 12288]" - # t2028 = ltorch.linear(t2027, t_transformer_h_14_attn_attn_weight, None) # t2028: "cuda:0 bf16[1, 512, 12288]" - # t2028 = prims.linear(t2027, t_transformer_h_14_attn_attn_weight, None) # t2028: "cuda:0 bf16[1, 512, 12288]" - t2034 = prims.reshape(t2028, (1, 512, 32, 3, 128)) # t2034: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2040 = prims.transpose(t2034, (0, 2, 3, 1, 4)) # t2040: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2041, t2042, t2043) = ltorch.split(t2040, (1, 1, 1), 2) - # t2041 = prims.slice_prim(t2040, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2041: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2042 = prims.slice_prim(t2040, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2042: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2043 = prims.slice_prim(t2040, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2043: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2049 = prims.reshape(t2041, (1, 32, 512, 128)) # t2049: "cuda:0 bf16[1, 32, 512, 128]" - t2055 = prims.reshape(t2042, (1, 32, 512, 128)) # t2055: "cuda:0 bf16[1, 32, 512, 128]" - t2061 = prims.reshape(t2043, (1, 32, 512, 128)) # t2061: "cuda:0 bf16[1, 32, 512, 128]" - t2062 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2062: "cuda:0 bf16[1, 32, 512, 128]" - t2063 = prims.slice_prim(t2062, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2063: "cuda:0 bf16[1, 32, 512, 64]" - t2064 = prims.slice_prim(t2062, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2064: "cuda:0 bf16[1, 32, 512, 64]" - t2065 = prims.convert_element_type(t2064, dtypes.float32) # t2065: "cuda:0 f32[1, 32, 512, 64]" - t2066 = prims.neg(t2065) # t2066: "cuda:0 f32[1, 32, 512, 64]" - t2067 = prims.convert_element_type(t2066, dtypes.bfloat16) # t2067: "cuda:0 bf16[1, 32, 512, 64]" - t2069 = prims.cat((t2067, t2063), -1) # t2069: "cuda:0 bf16[1, 32, 512, 128]" - t2070 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2070: "cuda:0 f32[1, 32, 512, 128]" - t2071 = prims.convert_element_type(t2062, dtypes.float32) # t2071: "cuda:0 f32[1, 32, 512, 128]" - t2072 = ltorch.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - # t2072 = prims.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]" - t2073 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2073: "cuda:0 f32[1, 32, 512, 128]" - t2074 = prims.convert_element_type(t2069, dtypes.float32) # t2074: "cuda:0 f32[1, 32, 512, 128]" - t2075 = ltorch.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - # t2075 = prims.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]" - t2076 = ltorch.add(t2072, t2075, alpha=None) # t2076: "cuda:0 f32[1, 32, 512, 128]" - # t2076 = prims.add(t2072, t2075) # t2076: "cuda:0 f32[1, 32, 512, 128]" - t2077 = prims.convert_element_type(t2076, dtypes.bfloat16) # t2077: "cuda:0 bf16[1, 32, 512, 128]" - t2078 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2078: "cuda:0 bf16[1, 32, 512, 128]" - t2079 = prims.slice_prim(t2078, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2079: "cuda:0 bf16[1, 32, 512, 64]" - t2080 = prims.slice_prim(t2078, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2080: "cuda:0 bf16[1, 32, 512, 64]" - t2081 = prims.convert_element_type(t2080, dtypes.float32) # t2081: "cuda:0 f32[1, 32, 512, 64]" - t2082 = prims.neg(t2081) # t2082: "cuda:0 f32[1, 32, 512, 64]" - t2083 = prims.convert_element_type(t2082, dtypes.bfloat16) # t2083: "cuda:0 bf16[1, 32, 512, 64]" - t2085 = prims.cat((t2083, t2079), -1) # t2085: "cuda:0 bf16[1, 32, 512, 128]" - t2086 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2086: "cuda:0 f32[1, 32, 512, 128]" - t2087 = prims.convert_element_type(t2078, dtypes.float32) # t2087: "cuda:0 f32[1, 32, 512, 128]" - t2088 = ltorch.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - # t2088 = prims.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]" - t2089 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2089: "cuda:0 f32[1, 32, 512, 128]" - t2090 = prims.convert_element_type(t2085, dtypes.float32) # t2090: "cuda:0 f32[1, 32, 512, 128]" - t2091 = ltorch.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - # t2091 = prims.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]" - t2092 = ltorch.add(t2088, t2091, alpha=None) # t2092: "cuda:0 f32[1, 32, 512, 128]" - # t2092 = prims.add(t2088, t2091) # t2092: "cuda:0 f32[1, 32, 512, 128]" - t2093 = prims.convert_element_type(t2092, dtypes.bfloat16) # t2093: "cuda:0 bf16[1, 32, 512, 128]" - t2094 = prims.slice_prim(t2049, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2094: "cuda:0 bf16[1, 32, 512, 0]" - t2096 = prims.cat((t2077, t2094), -1) # t2096: "cuda:0 bf16[1, 32, 512, 128]" - t2097 = prims.slice_prim(t2055, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2097: "cuda:0 bf16[1, 32, 512, 0]" - t2099 = prims.cat((t2093, t2097), -1) # t2099: "cuda:0 bf16[1, 32, 512, 128]" - (t2100, t2101, t2102, t2103) = cudnn_sdpa_fwd(t2096, t2099, t2061, None, 0.0, True, scale=0.08838834764831843) - t2106 = prims.transpose(t2100, (0, 2, 1, 3)) # t2106: "cuda:0 bf16[1, 512, 32, 128]" - t2110 = prims.reshape(t2106, (1, 512, 4096)) # t2110: "cuda:0 bf16[1, 512, 4096]" - t2111 = torch.nn.functional.linear(t2110, t_transformer_h_14_attn_proj_weight, None) # t2111: "cuda:0 bf16[1, 512, 4096]" - # t2111 = ltorch.linear(t2110, t_transformer_h_14_attn_proj_weight, None) # t2111: "cuda:0 bf16[1, 512, 4096]" - # t2111 = prims.linear(t2110, t_transformer_h_14_attn_proj_weight, None) # t2111: "cuda:0 bf16[1, 512, 4096]" - t2112 = prims.convert_element_type(t2111, dtypes.float32) # t2112: "cuda:0 f32[1, 512, 4096]" - t2113 = prims.convert_element_type(t2009, dtypes.float32) # t2113: "cuda:0 f32[1, 512, 4096]" - t2114 = ltorch.add(t2112, t2113, alpha=None) # t2114: "cuda:0 f32[1, 512, 4096]" - # t2114 = prims.add(t2112, t2113) # t2114: "cuda:0 f32[1, 512, 4096]" - t2115 = prims.convert_element_type(t2114, dtypes.bfloat16) # t2115: "cuda:0 bf16[1, 512, 4096]" - t2116 = prims.convert_element_type(t2115, dtypes.float32) # t2116: "cuda:0 f32[1, 512, 4096]" - t2117 = ltorch.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - # t2117 = prims.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]" - t2119 = prims.sum(t2117, (2,)) # t2119: "cuda:0 f32[1, 512]" - t2120 = prims.broadcast_in_dim(t2119, [1, 512, 1], [0, 1]) # t2120: "cuda:0 f32[1, 512, 1]" - t2122 = ltorch.true_divide(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - # t2122 = prims.div(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]" - t2124 = ltorch.add(t2122, 1e-05, alpha=None) # t2124: "cuda:0 f32[1, 512, 1]" - # t2124 = prims.add(t2122, 1e-05) # t2124: "cuda:0 f32[1, 512, 1]" - t2125 = prims.rsqrt(t2124) # t2125: "cuda:0 f32[1, 512, 1]" - t2126 = prims.broadcast_in_dim(t2125, (1, 512, 4096), (0, 1, 2)) # t2126: "cuda:0 f32[1, 512, 4096]" - t2127 = ltorch.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - # t2127 = prims.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]" - t2128 = prims.convert_element_type(t2127, dtypes.bfloat16) # t2128: "cuda:0 bf16[1, 512, 4096]" - t2129 = prims.broadcast_in_dim(t_transformer_h_14_norm_2_weight, (1, 512, 4096), (2,)) # t2129: "cuda:0 bf16[1, 512, 4096]" - t2130 = prims.convert_element_type(t2128, dtypes.float32) # t2130: "cuda:0 f32[1, 512, 4096]" - t2131 = prims.convert_element_type(t2129, dtypes.float32) # t2131: "cuda:0 f32[1, 512, 4096]" - t2132 = ltorch.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - # t2132 = prims.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]" - t2133 = prims.convert_element_type(t2132, dtypes.bfloat16) # t2133: "cuda:0 bf16[1, 512, 4096]" - t2134 = torch.nn.functional.linear(t2133, t_transformer_h_14_mlp_fc_1_weight, None) # t2134: "cuda:0 bf16[1, 512, 11008]" - # t2134 = ltorch.linear(t2133, t_transformer_h_14_mlp_fc_1_weight, None) # t2134: "cuda:0 bf16[1, 512, 11008]" - # t2134 = prims.linear(t2133, t_transformer_h_14_mlp_fc_1_weight, None) # t2134: "cuda:0 bf16[1, 512, 11008]" - t2135 = torch.nn.functional.linear(t2133, t_transformer_h_14_mlp_fc_2_weight, None) # t2135: "cuda:0 bf16[1, 512, 11008]" - # t2135 = ltorch.linear(t2133, t_transformer_h_14_mlp_fc_2_weight, None) # t2135: "cuda:0 bf16[1, 512, 11008]" - # t2135 = prims.linear(t2133, t_transformer_h_14_mlp_fc_2_weight, None) # t2135: "cuda:0 bf16[1, 512, 11008]" - t2136 = prims.convert_element_type(t2134, dtypes.float32) # t2136: "cuda:0 f32[1, 512, 11008]" - t2137 = prims.neg(t2136) # t2137: "cuda:0 f32[1, 512, 11008]" - t2138 = prims.exp(t2137) # t2138: "cuda:0 f32[1, 512, 11008]" - t2139 = ltorch.add(1.0, t2138, alpha=None) # t2139: "cuda:0 f32[1, 512, 11008]" - # t2139 = prims.add(1.0, t2138) # t2139: "cuda:0 f32[1, 512, 11008]" - t2140 = prims.reciprocal(t2139) # t2140: "cuda:0 f32[1, 512, 11008]" - t2141 = prims.convert_element_type(t2140, dtypes.bfloat16) # t2141: "cuda:0 bf16[1, 512, 11008]" - t2142 = prims.convert_element_type(t2134, dtypes.float32) # t2142: "cuda:0 f32[1, 512, 11008]" - t2143 = prims.convert_element_type(t2141, dtypes.float32) # t2143: "cuda:0 f32[1, 512, 11008]" - t2144 = ltorch.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - # t2144 = prims.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]" - t2145 = prims.convert_element_type(t2144, dtypes.bfloat16) # t2145: "cuda:0 bf16[1, 512, 11008]" - t2146 = prims.convert_element_type(t2145, dtypes.float32) # t2146: "cuda:0 f32[1, 512, 11008]" - t2147 = prims.convert_element_type(t2135, dtypes.float32) # t2147: "cuda:0 f32[1, 512, 11008]" - t2148 = ltorch.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - # t2148 = prims.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]" - t2149 = prims.convert_element_type(t2148, dtypes.bfloat16) # t2149: "cuda:0 bf16[1, 512, 11008]" - t2150 = torch.nn.functional.linear(t2149, t_transformer_h_14_mlp_proj_weight, None) # t2150: "cuda:0 bf16[1, 512, 4096]" - # t2150 = ltorch.linear(t2149, t_transformer_h_14_mlp_proj_weight, None) # t2150: "cuda:0 bf16[1, 512, 4096]" - # t2150 = prims.linear(t2149, t_transformer_h_14_mlp_proj_weight, None) # t2150: "cuda:0 bf16[1, 512, 4096]" - t2151 = prims.convert_element_type(t2150, dtypes.float32) # t2151: "cuda:0 f32[1, 512, 4096]" - t2152 = prims.convert_element_type(t2115, dtypes.float32) # t2152: "cuda:0 f32[1, 512, 4096]" - t2153 = ltorch.add(t2151, t2152, alpha=None) # t2153: "cuda:0 f32[1, 512, 4096]" - # t2153 = prims.add(t2151, t2152) # t2153: "cuda:0 f32[1, 512, 4096]" - t2154 = prims.convert_element_type(t2153, dtypes.bfloat16) # t2154: "cuda:0 bf16[1, 512, 4096]" - t2155 = prims.convert_element_type(t2154, dtypes.float32) # t2155: "cuda:0 f32[1, 512, 4096]" - t2156 = ltorch.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - # t2156 = prims.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]" - t2158 = prims.sum(t2156, (2,)) # t2158: "cuda:0 f32[1, 512]" - t2159 = prims.broadcast_in_dim(t2158, [1, 512, 1], [0, 1]) # t2159: "cuda:0 f32[1, 512, 1]" - t2161 = ltorch.true_divide(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - # t2161 = prims.div(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]" - t2163 = ltorch.add(t2161, 1e-05, alpha=None) # t2163: "cuda:0 f32[1, 512, 1]" - # t2163 = prims.add(t2161, 1e-05) # t2163: "cuda:0 f32[1, 512, 1]" - t2164 = prims.rsqrt(t2163) # t2164: "cuda:0 f32[1, 512, 1]" - t2165 = prims.broadcast_in_dim(t2164, (1, 512, 4096), (0, 1, 2)) # t2165: "cuda:0 f32[1, 512, 4096]" - t2166 = ltorch.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - # t2166 = prims.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]" - t2167 = prims.convert_element_type(t2166, dtypes.bfloat16) # t2167: "cuda:0 bf16[1, 512, 4096]" - t2168 = prims.broadcast_in_dim(t_transformer_h_15_norm_1_weight, (1, 512, 4096), (2,)) # t2168: "cuda:0 bf16[1, 512, 4096]" - t2169 = prims.convert_element_type(t2167, dtypes.float32) # t2169: "cuda:0 f32[1, 512, 4096]" - t2170 = prims.convert_element_type(t2168, dtypes.float32) # t2170: "cuda:0 f32[1, 512, 4096]" - t2171 = ltorch.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - # t2171 = prims.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]" - t2172 = prims.convert_element_type(t2171, dtypes.bfloat16) # t2172: "cuda:0 bf16[1, 512, 4096]" - t2173 = torch.nn.functional.linear(t2172, t_transformer_h_15_attn_attn_weight, None) # t2173: "cuda:0 bf16[1, 512, 12288]" - # t2173 = ltorch.linear(t2172, t_transformer_h_15_attn_attn_weight, None) # t2173: "cuda:0 bf16[1, 512, 12288]" - # t2173 = prims.linear(t2172, t_transformer_h_15_attn_attn_weight, None) # t2173: "cuda:0 bf16[1, 512, 12288]" - t2179 = prims.reshape(t2173, (1, 512, 32, 3, 128)) # t2179: "cuda:0 bf16[1, 512, 32, 3, 128]" - t2185 = prims.transpose(t2179, (0, 2, 3, 1, 4)) # t2185: "cuda:0 bf16[1, 32, 3, 512, 128]" - (t2186, t2187, t2188) = ltorch.split(t2185, (1, 1, 1), 2) - # t2186 = prims.slice_prim(t2185, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2186: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2187 = prims.slice_prim(t2185, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2187: "cuda:0 bf16[1, 32, 1, 512, 128]" - # t2188 = prims.slice_prim(t2185, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2188: "cuda:0 bf16[1, 32, 1, 512, 128]" - t2194 = prims.reshape(t2186, (1, 32, 512, 128)) # t2194: "cuda:0 bf16[1, 32, 512, 128]" - t2200 = prims.reshape(t2187, (1, 32, 512, 128)) # t2200: "cuda:0 bf16[1, 32, 512, 128]" - t2206 = prims.reshape(t2188, (1, 32, 512, 128)) # t2206: "cuda:0 bf16[1, 32, 512, 128]" - t2207 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2207: "cuda:0 bf16[1, 32, 512, 128]" - t2208 = prims.slice_prim(t2207, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2208: "cuda:0 bf16[1, 32, 512, 64]" - t2209 = prims.slice_prim(t2207, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2209: "cuda:0 bf16[1, 32, 512, 64]" - t2210 = prims.convert_element_type(t2209, dtypes.float32) # t2210: "cuda:0 f32[1, 32, 512, 64]" - t2211 = prims.neg(t2210) # t2211: "cuda:0 f32[1, 32, 512, 64]" - t2212 = prims.convert_element_type(t2211, dtypes.bfloat16) # t2212: "cuda:0 bf16[1, 32, 512, 64]" - t2214 = prims.cat((t2212, t2208), -1) # t2214: "cuda:0 bf16[1, 32, 512, 128]" - t2215 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2215: "cuda:0 f32[1, 32, 512, 128]" - t2216 = prims.convert_element_type(t2207, dtypes.float32) # t2216: "cuda:0 f32[1, 32, 512, 128]" - t2217 = ltorch.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - # t2217 = prims.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]" - t2218 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2218: "cuda:0 f32[1, 32, 512, 128]" - t2219 = prims.convert_element_type(t2214, dtypes.float32) # t2219: "cuda:0 f32[1, 32, 512, 128]" - t2220 = ltorch.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - # t2220 = prims.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]" - t2221 = ltorch.add(t2217, t2220, alpha=None) # t2221: "cuda:0 f32[1, 32, 512, 128]" - # t2221 = prims.add(t2217, t2220) # t2221: "cuda:0 f32[1, 32, 512, 128]" - t2222 = prims.convert_element_type(t2221, dtypes.bfloat16) # t2222: "cuda:0 bf16[1, 32, 512, 128]" - t2223 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 128], [1, 1, 1, 1]) # t2223: "cuda:0 bf16[1, 32, 512, 128]" - t2224 = prims.slice_prim(t2223, [0, 0, 0, 0], [1, 32, 512, 64], [1, 1, 1, 1]) # t2224: "cuda:0 bf16[1, 32, 512, 64]" - t2225 = prims.slice_prim(t2223, [0, 0, 0, 64], [1, 32, 512, 128], [1, 1, 1, 1]) # t2225: "cuda:0 bf16[1, 32, 512, 64]" - t2226 = prims.convert_element_type(t2225, dtypes.float32) # t2226: "cuda:0 f32[1, 32, 512, 64]" - t2227 = prims.neg(t2226) # t2227: "cuda:0 f32[1, 32, 512, 64]" - t2228 = prims.convert_element_type(t2227, dtypes.bfloat16) # t2228: "cuda:0 bf16[1, 32, 512, 64]" - t2230 = prims.cat((t2228, t2224), -1) # t2230: "cuda:0 bf16[1, 32, 512, 128]" - t2231 = prims.broadcast_in_dim(t0, (1, 32, 512, 128), (2, 3)) # t2231: "cuda:0 f32[1, 32, 512, 128]" - t2232 = prims.convert_element_type(t2223, dtypes.float32) # t2232: "cuda:0 f32[1, 32, 512, 128]" - t2233 = ltorch.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - # t2233 = prims.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]" - t2234 = prims.broadcast_in_dim(t1, (1, 32, 512, 128), (2, 3)) # t2234: "cuda:0 f32[1, 32, 512, 128]" - t2235 = prims.convert_element_type(t2230, dtypes.float32) # t2235: "cuda:0 f32[1, 32, 512, 128]" - t2236 = ltorch.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - # t2236 = prims.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]" - t2237 = ltorch.add(t2233, t2236, alpha=None) # t2237: "cuda:0 f32[1, 32, 512, 128]" - # t2237 = prims.add(t2233, t2236) # t2237: "cuda:0 f32[1, 32, 512, 128]" - t2238 = prims.convert_element_type(t2237, dtypes.bfloat16) # t2238: "cuda:0 bf16[1, 32, 512, 128]" - t2239 = prims.slice_prim(t2194, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2239: "cuda:0 bf16[1, 32, 512, 0]" - t2241 = prims.cat((t2222, t2239), -1) # t2241: "cuda:0 bf16[1, 32, 512, 128]" - t2242 = prims.slice_prim(t2200, [0, 0, 0, 0], [1, 32, 512, 0], [1, 1, 1, 1]) # t2242: "cuda:0 bf16[1, 32, 512, 0]" - t2244 = prims.cat((t2238, t2242), -1) # t2244: "cuda:0 bf16[1, 32, 512, 128]" - (t2245, t2246, t2247, t2248) = cudnn_sdpa_fwd(t2241, t2244, t2206, None, 0.0, True, scale=0.08838834764831843) - t2251 = prims.transpose(t2245, (0, 2, 1, 3)) # t2251: "cuda:0 bf16[1, 512, 32, 128]" - t2255 = prims.reshape(t2251, (1, 512, 4096)) # t2255: "cuda:0 bf16[1, 512, 4096]" - t2256 = torch.nn.functional.linear(t2255, t_transformer_h_15_attn_proj_weight, None) # t2256: "cuda:0 bf16[1, 512, 4096]" - # t2256 = ltorch.linear(t2255, t_transformer_h_15_attn_proj_weight, None) # t2256: "cuda:0 bf16[1, 512, 4096]" - # t2256 = prims.linear(t2255, t_transformer_h_15_attn_proj_weight, None) # t2256: "cuda:0 bf16[1, 512, 4096]" - t2257 = prims.convert_element_type(t2256, dtypes.float32) # t2257: "cuda:0 f32[1, 512, 4096]" - t2258 = prims.convert_element_type(t2154, dtypes.float32) # t2258: "cuda:0 f32[1, 512, 4096]" - t2259 = ltorch.add(t2257, t2258, alpha=None) # t2259: "cuda:0 f32[1, 512, 4096]" - # t2259 = prims.add(t2257, t2258) # t2259: "cuda:0 f32[1, 512, 4096]" - t2260 = prims.convert_element_type(t2259, dtypes.bfloat16) # t2260: "cuda:0 bf16[1, 512, 4096]" - t2261 = prims.convert_element_type(t2260, dtypes.float32) # t2261: "cuda:0 f32[1, 512, 4096]" - t2262 = ltorch.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - # t2262 = prims.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]" - t2264 = prims.sum(t2262, (2,)) # t2264: "cuda:0 f32[1, 512]" - t2265 = prims.broadcast_in_dim(t2264, [1, 512, 1], [0, 1]) # t2265: "cuda:0 f32[1, 512, 1]" - t2267 = ltorch.true_divide(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - # t2267 = prims.div(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]" - t2269 = ltorch.add(t2267, 1e-05, alpha=None) # t2269: "cuda:0 f32[1, 512, 1]" - # t2269 = prims.add(t2267, 1e-05) # t2269: "cuda:0 f32[1, 512, 1]" - t2270 = prims.rsqrt(t2269) # t2270: "cuda:0 f32[1, 512, 1]" - t2271 = prims.broadcast_in_dim(t2270, (1, 512, 4096), (0, 1, 2)) # t2271: "cuda:0 f32[1, 512, 4096]" - t2272 = ltorch.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - # t2272 = prims.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]" - t2273 = prims.convert_element_type(t2272, dtypes.bfloat16) # t2273: "cuda:0 bf16[1, 512, 4096]" - t2274 = prims.broadcast_in_dim(t_transformer_h_15_norm_2_weight, (1, 512, 4096), (2,)) # t2274: "cuda:0 bf16[1, 512, 4096]" - t2275 = prims.convert_element_type(t2273, dtypes.float32) # t2275: "cuda:0 f32[1, 512, 4096]" - t2276 = prims.convert_element_type(t2274, dtypes.float32) # t2276: "cuda:0 f32[1, 512, 4096]" - t2277 = ltorch.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - # t2277 = prims.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]" - t2278 = prims.convert_element_type(t2277, dtypes.bfloat16) # t2278: "cuda:0 bf16[1, 512, 4096]" - t2279 = torch.nn.functional.linear(t2278, t_transformer_h_15_mlp_fc_1_weight, None) # t2279: "cuda:0 bf16[1, 512, 11008]" - # t2279 = ltorch.linear(t2278, t_transformer_h_15_mlp_fc_1_weight, None) # t2279: "cuda:0 bf16[1, 512, 11008]" - # t2279 = prims.linear(t2278, t_transformer_h_15_mlp_fc_1_weight, None) # t2279: "cuda:0 bf16[1, 512, 11008]" - t2280 = torch.nn.functional.linear(t2278, t_transformer_h_15_mlp_fc_2_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - # t2280 = ltorch.linear(t2278, t_transformer_h_15_mlp_fc_2_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - # t2280 = prims.linear(t2278, t_transformer_h_15_mlp_fc_2_weight, None) # t2280: "cuda:0 bf16[1, 512, 11008]" - t2281 = prims.convert_element_type(t2279, dtypes.float32) # t2281: "cuda:0 f32[1, 512, 11008]" - t2282 = prims.neg(t2281) # t2282: "cuda:0 f32[1, 512, 11008]" - t2283 = prims.exp(t2282) # t2283: "cuda:0 f32[1, 512, 11008]" - t2284 = ltorch.add(1.0, t2283, alpha=None) # t2284: "cuda:0 f32[1, 512, 11008]" - # t2284 = prims.add(1.0, t2283) # t2284: "cuda:0 f32[1, 512, 11008]" - t2285 = prims.reciprocal(t2284) # t2285: "cuda:0 f32[1, 512, 11008]" - t2286 = prims.convert_element_type(t2285, dtypes.bfloat16) # t2286: "cuda:0 bf16[1, 512, 11008]" - t2287 = prims.convert_element_type(t2279, dtypes.float32) # t2287: "cuda:0 f32[1, 512, 11008]" - t2288 = prims.convert_element_type(t2286, dtypes.float32) # t2288: "cuda:0 f32[1, 512, 11008]" - t2289 = ltorch.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - # t2289 = prims.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]" - t2290 = prims.convert_element_type(t2289, dtypes.bfloat16) # t2290: "cuda:0 bf16[1, 512, 11008]" - t2291 = prims.convert_element_type(t2290, dtypes.float32) # t2291: "cuda:0 f32[1, 512, 11008]" - t2292 = prims.convert_element_type(t2280, dtypes.float32) # t2292: "cuda:0 f32[1, 512, 11008]" - t2293 = ltorch.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - # t2293 = prims.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]" - t2294 = prims.convert_element_type(t2293, dtypes.bfloat16) # t2294: "cuda:0 bf16[1, 512, 11008]" - t2295 = torch.nn.functional.linear(t2294, t_transformer_h_15_mlp_proj_weight, None) # t2295: "cuda:0 bf16[1, 512, 4096]" - # t2295 = ltorch.linear(t2294, t_transformer_h_15_mlp_proj_weight, None) # t2295: "cuda:0 bf16[1, 512, 4096]" - # t2295 = prims.linear(t2294, t_transformer_h_15_mlp_proj_weight, None) # t2295: "cuda:0 bf16[1, 512, 4096]" - t2296 = prims.convert_element_type(t2295, dtypes.float32) # t2296: "cuda:0 f32[1, 512, 4096]" - t2297 = prims.convert_element_type(t2260, dtypes.float32) # t2297: "cuda:0 f32[1, 512, 4096]" - t2298 = ltorch.add(t2296, t2297, alpha=None) # t2298: "cuda:0 f32[1, 512, 4096]" - # t2298 = prims.add(t2296, t2297) # t2298: "cuda:0 f32[1, 512, 4096]" - t2299 = prims.convert_element_type(t2298, dtypes.bfloat16) # t2299: "cuda:0 bf16[1, 512, 4096]" - t2300 = prims.convert_element_type(t2299, dtypes.float32) # t2300: "cuda:0 f32[1, 512, 4096]" - t2301 = ltorch.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - # t2301 = prims.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]" - t2303 = prims.sum(t2301, (2,)) # t2303: "cuda:0 f32[1, 512]" - t2304 = prims.broadcast_in_dim(t2303, [1, 512, 1], [0, 1]) # t2304: "cuda:0 f32[1, 512, 1]" - t2306 = ltorch.true_divide(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - # t2306 = prims.div(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]" - t2308 = ltorch.add(t2306, 1e-05, alpha=None) # t2308: "cuda:0 f32[1, 512, 1]" - # t2308 = prims.add(t2306, 1e-05) # t2308: "cuda:0 f32[1, 512, 1]" - t2309 = prims.rsqrt(t2308) # t2309: "cuda:0 f32[1, 512, 1]" - t2310 = prims.broadcast_in_dim(t2309, (1, 512, 4096), (0, 1, 2)) # t2310: "cuda:0 f32[1, 512, 4096]" - t2311 = ltorch.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - # t2311 = prims.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]" - t2312 = prims.convert_element_type(t2311, dtypes.bfloat16) # t2312: "cuda:0 bf16[1, 512, 4096]" - t2313 = prims.broadcast_in_dim(t_transformer_ln_f_weight, (1, 512, 4096), (2,)) # t2313: "cuda:0 bf16[1, 512, 4096]" - t2314 = prims.convert_element_type(t2312, dtypes.float32) # t2314: "cuda:0 f32[1, 512, 4096]" - t2315 = prims.convert_element_type(t2313, dtypes.float32) # t2315: "cuda:0 f32[1, 512, 4096]" - t2316 = ltorch.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - # t2316 = prims.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]" - t2317 = prims.convert_element_type(t2316, dtypes.bfloat16) # t2317: "cuda:0 bf16[1, 512, 4096]" - t2318 = torch.nn.functional.linear(t2317, t_lm_head_weight, None) # t2318: "cuda:0 bf16[1, 512, 32000]" - # t2318 = ltorch.linear(t2317, t_lm_head_weight, None) # t2318: "cuda:0 bf16[1, 512, 32000]" - # t2318 = prims.linear(t2317, t_lm_head_weight, None) # t2318: "cuda:0 bf16[1, 512, 32000]" - return {'output': t2318, 'flat_args': [idx, tos1, t_lm_head_weight, t_sin, t_transformer_h_0_attn_attn_weight, t_transformer_h_0_attn_proj_weight, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t_transformer_h_0_mlp_proj_weight, t_transformer_h_0_norm_1_weight, t_transformer_h_0_norm_2_weight, t_transformer_h_1_attn_attn_weight, t_transformer_h_1_attn_proj_weight, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t_transformer_h_1_mlp_proj_weight, t_transformer_h_1_norm_1_weight, t_transformer_h_1_norm_2_weight, t_transformer_h_2_attn_attn_weight, t_transformer_h_2_attn_proj_weight, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t_transformer_h_2_mlp_proj_weight, t_transformer_h_2_norm_1_weight, t_transformer_h_2_norm_2_weight, t_transformer_h_3_attn_attn_weight, t_transformer_h_3_attn_proj_weight, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t_transformer_h_3_mlp_proj_weight, t_transformer_h_3_norm_1_weight, t_transformer_h_3_norm_2_weight, t_transformer_h_4_attn_attn_weight, t_transformer_h_4_attn_proj_weight, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t_transformer_h_4_mlp_proj_weight, t_transformer_h_4_norm_1_weight, t_transformer_h_4_norm_2_weight, t_transformer_h_5_attn_attn_weight, t_transformer_h_5_attn_proj_weight, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t_transformer_h_5_mlp_proj_weight, t_transformer_h_5_norm_1_weight, t_transformer_h_5_norm_2_weight, t_transformer_h_6_attn_attn_weight, t_transformer_h_6_attn_proj_weight, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t_transformer_h_6_mlp_proj_weight, t_transformer_h_6_norm_1_weight, t_transformer_h_6_norm_2_weight, t_transformer_h_7_attn_attn_weight, t_transformer_h_7_attn_proj_weight, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t_transformer_h_7_mlp_proj_weight, t_transformer_h_7_norm_1_weight, t_transformer_h_7_norm_2_weight, t_transformer_h_8_attn_attn_weight, t_transformer_h_8_attn_proj_weight, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t_transformer_h_8_mlp_proj_weight, t_transformer_h_8_norm_1_weight, t_transformer_h_8_norm_2_weight, t_transformer_h_9_attn_attn_weight, t_transformer_h_9_attn_proj_weight, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t_transformer_h_9_mlp_proj_weight, t_transformer_h_9_norm_1_weight, t_transformer_h_9_norm_2_weight, t_transformer_h_10_attn_attn_weight, t_transformer_h_10_attn_proj_weight, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t_transformer_h_10_mlp_proj_weight, t_transformer_h_10_norm_1_weight, t_transformer_h_10_norm_2_weight, t_transformer_h_11_attn_attn_weight, t_transformer_h_11_attn_proj_weight, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t_transformer_h_11_mlp_proj_weight, t_transformer_h_11_norm_1_weight, t_transformer_h_11_norm_2_weight, t_transformer_h_12_attn_attn_weight, t_transformer_h_12_attn_proj_weight, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t_transformer_h_12_mlp_proj_weight, t_transformer_h_12_norm_1_weight, t_transformer_h_12_norm_2_weight, t_transformer_h_13_attn_attn_weight, t_transformer_h_13_attn_proj_weight, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t_transformer_h_13_mlp_proj_weight, t_transformer_h_13_norm_1_weight, t_transformer_h_13_norm_2_weight, t_transformer_h_14_attn_attn_weight, t_transformer_h_14_attn_proj_weight, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t_transformer_h_14_mlp_proj_weight, t_transformer_h_14_norm_1_weight, t_transformer_h_14_norm_2_weight, t_transformer_h_15_attn_attn_weight, t_transformer_h_15_attn_proj_weight, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t_transformer_h_15_mlp_proj_weight, t_transformer_h_15_norm_1_weight, t_transformer_h_15_norm_2_weight, t_transformer_ln_f_weight, t_transformer_wte_weight], 'flat_output': (t2318,)}, ((idx, t5, t11, t12, t17, t16, t19, t_transformer_h_0_attn_attn_weight, t46, t47, t49, t50, t62, t63, t65, t66, t71, t74, t38, t75, t76, t77, t78, t80, t_transformer_h_0_attn_proj_weight, t86, t95, t96, t101, t100, t103, t_transformer_h_0_mlp_fc_1_weight, t_transformer_h_0_mlp_fc_2_weight, t108, t110, t113, t112, t117, t116, t119, t_transformer_h_0_mlp_proj_weight, t125, t134, t135, t140, t139, t142, t_transformer_h_1_attn_attn_weight, t185, t186, t188, t189, t201, t202, t204, t205, t211, t214, t176, t215, t216, t217, t218, t225, t_transformer_h_1_attn_proj_weight, t231, t240, t241, t246, t245, t248, t_transformer_h_1_mlp_fc_1_weight, t_transformer_h_1_mlp_fc_2_weight, t253, t255, t258, t257, t262, t261, t264, t_transformer_h_1_mlp_proj_weight, t270, t279, t280, t285, t284, t287, t_transformer_h_2_attn_attn_weight, t330, t331, t333, t334, t346, t347, t349, t350, t356, t359, t321, t360, t361, t362, t363, t370, t_transformer_h_2_attn_proj_weight, t376, t385, t386, t391, t390, t393, t_transformer_h_2_mlp_fc_1_weight, t_transformer_h_2_mlp_fc_2_weight, t398, t400, t403, t402, t407, t406, t409, t_transformer_h_2_mlp_proj_weight, t415, t424, t425, t430, t429, t432, t_transformer_h_3_attn_attn_weight, t475, t476, t478, t479, t491, t492, t494, t495, t501, t504, t466, t505, t506, t507, t508, t515, t_transformer_h_3_attn_proj_weight, t521, t530, t531, t536, t535, t538, t_transformer_h_3_mlp_fc_1_weight, t_transformer_h_3_mlp_fc_2_weight, t543, t545, t548, t547, t552, t551, t554, t_transformer_h_3_mlp_proj_weight, t560, t569, t570, t575, t574, t577, t_transformer_h_4_attn_attn_weight, t620, t621, t623, t624, t636, t637, t639, t640, t646, t649, t611, t650, t651, t652, t653, t660, t_transformer_h_4_attn_proj_weight, t666, t675, t676, t681, t680, t683, t_transformer_h_4_mlp_fc_1_weight, t_transformer_h_4_mlp_fc_2_weight, t688, t690, t693, t692, t697, t696, t699, t_transformer_h_4_mlp_proj_weight, t705, t714, t715, t720, t719, t722, t_transformer_h_5_attn_attn_weight, t765, t766, t768, t769, t781, t782, t784, t785, t791, t794, t756, t795, t796, t797, t798, t805, t_transformer_h_5_attn_proj_weight, t811, t820, t821, t826, t825, t828, t_transformer_h_5_mlp_fc_1_weight, t_transformer_h_5_mlp_fc_2_weight, t833, t835, t838, t837, t842, t841, t844, t_transformer_h_5_mlp_proj_weight, t850, t859, t860, t865, t864, t867, t_transformer_h_6_attn_attn_weight, t910, t911, t913, t914, t926, t927, t929, t930, t936, t939, t901, t940, t941, t942, t943, t950, t_transformer_h_6_attn_proj_weight, t956, t965, t966, t971, t970, t973, t_transformer_h_6_mlp_fc_1_weight, t_transformer_h_6_mlp_fc_2_weight, t978, t980, t983, t982, t987, t986, t989, t_transformer_h_6_mlp_proj_weight, t995, t1004, t1005, t1010, t1009, t1012, t_transformer_h_7_attn_attn_weight, t1055, t1056, t1058, t1059, t1071, t1072, t1074, t1075, t1081, t1084, t1046, t1085, t1086, t1087, t1088, t1095, t_transformer_h_7_attn_proj_weight, t1101, t1110, t1111, t1116, t1115, t1118, t_transformer_h_7_mlp_fc_1_weight, t_transformer_h_7_mlp_fc_2_weight, t1123, t1125, t1128, t1127, t1132, t1131, t1134, t_transformer_h_7_mlp_proj_weight, t1140, t1149, t1150, t1155, t1154, t1157, t_transformer_h_8_attn_attn_weight, t1200, t1201, t1203, t1204, t1216, t1217, t1219, t1220, t1226, t1229, t1191, t1230, t1231, t1232, t1233, t1240, t_transformer_h_8_attn_proj_weight, t1246, t1255, t1256, t1261, t1260, t1263, t_transformer_h_8_mlp_fc_1_weight, t_transformer_h_8_mlp_fc_2_weight, t1268, t1270, t1273, t1272, t1277, t1276, t1279, t_transformer_h_8_mlp_proj_weight, t1285, t1294, t1295, t1300, t1299, t1302, t_transformer_h_9_attn_attn_weight, t1345, t1346, t1348, t1349, t1361, t1362, t1364, t1365, t1371, t1374, t1336, t1375, t1376, t1377, t1378, t1385, t_transformer_h_9_attn_proj_weight, t1391, t1400, t1401, t1406, t1405, t1408, t_transformer_h_9_mlp_fc_1_weight, t_transformer_h_9_mlp_fc_2_weight, t1413, t1415, t1418, t1417, t1422, t1421, t1424, t_transformer_h_9_mlp_proj_weight, t1430, t1439, t1440, t1445, t1444, t1447, t_transformer_h_10_attn_attn_weight, t1490, t1491, t1493, t1494, t1506, t1507, t1509, t1510, t1516, t1519, t1481, t1520, t1521, t1522, t1523, t1530, t_transformer_h_10_attn_proj_weight, t1536, t1545, t1546, t1551, t1550, t1553, t_transformer_h_10_mlp_fc_1_weight, t_transformer_h_10_mlp_fc_2_weight, t1558, t1560, t1563, t1562, t1567, t1566, t1569, t_transformer_h_10_mlp_proj_weight, t1575, t1584, t1585, t1590, t1589, t1592, t_transformer_h_11_attn_attn_weight, t1635, t1636, t1638, t1639, t1651, t1652, t1654, t1655, t1661, t1664, t1626, t1665, t1666, t1667, t1668, t1675, t_transformer_h_11_attn_proj_weight, t1681, t1690, t1691, t1696, t1695, t1698, t_transformer_h_11_mlp_fc_1_weight, t_transformer_h_11_mlp_fc_2_weight, t1703, t1705, t1708, t1707, t1712, t1711, t1714, t_transformer_h_11_mlp_proj_weight, t1720, t1729, t1730, t1735, t1734, t1737, t_transformer_h_12_attn_attn_weight, t1780, t1781, t1783, t1784, t1796, t1797, t1799, t1800, t1806, t1809, t1771, t1810, t1811, t1812, t1813, t1820, t_transformer_h_12_attn_proj_weight, t1826, t1835, t1836, t1841, t1840, t1843, t_transformer_h_12_mlp_fc_1_weight, t_transformer_h_12_mlp_fc_2_weight, t1848, t1850, t1853, t1852, t1857, t1856, t1859, t_transformer_h_12_mlp_proj_weight, t1865, t1874, t1875, t1880, t1879, t1882, t_transformer_h_13_attn_attn_weight, t1925, t1926, t1928, t1929, t1941, t1942, t1944, t1945, t1951, t1954, t1916, t1955, t1956, t1957, t1958, t1965, t_transformer_h_13_attn_proj_weight, t1971, t1980, t1981, t1986, t1985, t1988, t_transformer_h_13_mlp_fc_1_weight, t_transformer_h_13_mlp_fc_2_weight, t1993, t1995, t1998, t1997, t2002, t2001, t2004, t_transformer_h_13_mlp_proj_weight, t2010, t2019, t2020, t2025, t2024, t2027, t_transformer_h_14_attn_attn_weight, t2070, t2071, t2073, t2074, t2086, t2087, t2089, t2090, t2096, t2099, t2061, t2100, t2101, t2102, t2103, t2110, t_transformer_h_14_attn_proj_weight, t2116, t2125, t2126, t2131, t2130, t2133, t_transformer_h_14_mlp_fc_1_weight, t_transformer_h_14_mlp_fc_2_weight, t2138, t2140, t2143, t2142, t2147, t2146, t2149, t_transformer_h_14_mlp_proj_weight, t2155, t2164, t2165, t2170, t2169, t2172, t_transformer_h_15_attn_attn_weight, t2215, t2216, t2218, t2219, t2231, t2232, t2234, t2235, t2241, t2244, t2206, t2245, t2246, t2247, t2248, t2255, t_transformer_h_15_attn_proj_weight, t2261, t2270, t2271, t2276, t2275, t2278, t_transformer_h_15_mlp_fc_1_weight, t_transformer_h_15_mlp_fc_2_weight, t2283, t2285, t2288, t2287, t2292, t2291, t2294, t_transformer_h_15_mlp_proj_weight, t2300, t2309, t2310, t2315, t2314, t2317, t_lm_head_weight), (32000, False, False, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0, 2, 0.0, True, 0.08838834764831843, 4096.0, 4096.0)) -============================================ GRAPH: _transform_for_operator_executor_execution - -cur node 0 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 1 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 2 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 3 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 4 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 5 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 6 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 7 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 8 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 9 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 10 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 11 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 12 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 13 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 14 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 15 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 16 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 17 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 18 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 19 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 20 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 21 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 22 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 23 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 24 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 25 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 26 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 27 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 28 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 29 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 30 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 31 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 32 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 33 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 34 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 35 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 36 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 37 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 38 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 39 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 40 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 41 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 42 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 43 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 44 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 45 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 46 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 47 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 48 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 49 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 50 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 51 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 52 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 53 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 54 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 55 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 56 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 57 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 58 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 59 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 60 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 61 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 62 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 63 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 64 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 65 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 66 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 67 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 68 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 69 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 70 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 71 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 72 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 73 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 74 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 75 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 76 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 77 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 78 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 79 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 80 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 81 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 82 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 83 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 84 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 85 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 86 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 87 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 88 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 89 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 90 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 91 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 92 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 93 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 94 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 95 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 96 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 97 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 98 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 99 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 100 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 101 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 102 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 103 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 104 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 105 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 106 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 107 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 108 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 109 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 110 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 111 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 112 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 113 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 114 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 115 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 116 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 117 group_bsyms len: 1 -[Symbol name=unpack_trivial] -[] - -cur node 120 group_bsyms len: 1 -[Symbol name=embedding] -[t2 = ltorch.reshape(idx, [512]) # t2: "cuda:0 i64[512]" - # t2 = prims.reshape(idx, (512,)) # t2: "cuda:0 i64[512]", t3 = prims.take(t_transformer_wte_weight, t2, 0) # t3: "cuda:0 bf16[512, 4096]", t4 = ltorch.reshape(t3, [1, 512, 4096]) # t4: "cuda:0 bf16[1, 512, 4096]" - # t4 = prims.reshape(t3, (1, 512, 4096)) # t4: "cuda:0 bf16[1, 512, 4096]"] - -cur node 1737 group_bsyms len: 1 -[Symbol name=return] -[] - -cur node 118 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1736 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 119 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 136 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 180 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 200 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 201 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 216 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 131 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 195 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 236 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 280 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 300 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 301 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 316 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 231 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 295 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 336 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 380 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 400 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 401 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 416 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 331 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 395 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 436 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 480 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 500 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 501 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 516 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 431 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 495 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 536 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 580 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 600 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 601 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 616 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 531 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 595 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 636 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 680 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 700 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 701 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 716 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 631 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 695 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 736 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 780 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 800 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 801 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 816 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 731 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 795 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 836 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 880 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 900 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 901 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 916 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 831 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 895 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 936 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 980 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1000 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1001 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1016 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 931 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 995 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1036 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1080 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1100 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1101 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1116 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1031 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1095 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1136 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1180 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1200 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1201 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1216 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1131 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1195 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1236 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1280 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1300 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1301 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1316 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1231 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1295 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1336 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1380 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1400 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1401 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1416 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1331 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1395 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1436 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1480 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1500 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1501 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1516 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1431 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1495 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1536 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1580 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1600 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1601 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1616 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1531 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1595 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1636 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1680 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1700 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1701 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1716 group_bsyms len: 1 -[Symbol name=linear] -[] - -cur node 1631 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1695 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1731 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 121 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 182 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1665 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 265 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 650 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1165 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1550 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 150 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 665 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1050 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1565 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 165 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 550 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1065 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1450 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 565 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 950 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1465 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 450 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 965 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1350 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 465 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 850 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1365 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 350 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 865 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1250 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 365 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 750 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1265 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1650 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 250 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 765 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1150 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 768 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1153 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1668 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 268 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 653 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1168 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1553 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 153 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 668 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1053 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1568 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 168 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 553 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1068 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1453 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 568 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 953 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1468 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 453 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 968 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1353 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 468 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 853 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1368 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 353 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 868 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1253 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 368 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 753 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1268 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1653 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 253 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 137 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 181 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 208 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 202 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 213 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 217 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 133 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 197 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 237 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 281 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 308 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 302 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 313 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 317 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 233 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 297 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 337 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 381 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 408 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 402 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 413 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 417 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 333 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 397 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 437 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 481 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 508 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 502 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 513 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 517 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 433 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 497 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 537 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 581 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 608 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 602 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 613 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 617 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 533 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 597 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 637 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 681 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 708 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 702 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 713 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 717 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 633 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 697 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 737 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 781 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 808 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 802 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 813 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 817 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 733 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 797 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 837 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 881 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 908 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 902 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 913 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 917 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 833 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 897 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 937 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 981 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1008 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1002 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1013 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1017 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 933 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 997 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1037 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1081 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1108 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1102 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1113 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1117 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1033 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1097 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1137 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1181 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1208 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1202 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1213 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1217 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1133 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1197 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1237 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1281 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1308 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1302 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1313 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1317 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1233 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1297 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1337 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1381 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1408 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1402 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1413 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1417 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1333 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1397 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1437 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1481 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1508 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1502 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1513 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1517 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1433 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1497 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1537 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1581 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1608 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1602 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1613 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1617 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1533 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1597 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1637 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1681 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1708 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1702 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1713 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1717 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1633 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1697 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1733 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 129 group_bsyms len: 1 -[Symbol name=mul] -[t13 = prims.mul(t5, t12) # t13: "cuda:0 f32[1, 512, 4096]"] - -cur node 122 group_bsyms len: 1 -[Symbol name=mul] -[t6 = prims.mul(t5, t5) # t6: "cuda:0 f32[1, 512, 4096]"] - -cur node 183 group_bsyms len: 1 -[Symbol name=add] -[t84 = prims.add(t82, t83) # t84: "cuda:0 f32[1, 512, 4096]"] - -cur node 1667 group_bsyms len: 1 -[Symbol name=mul] -[t2233 = prims.mul(t2232, t2231) # t2233: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 267 group_bsyms len: 1 -[Symbol name=mul] -[t203 = prims.mul(t202, t201) # t203: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 652 group_bsyms len: 1 -[Symbol name=mul] -[t767 = prims.mul(t766, t765) # t767: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1167 group_bsyms len: 1 -[Symbol name=mul] -[t1508 = prims.mul(t1507, t1506) # t1508: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1552 group_bsyms len: 1 -[Symbol name=mul] -[t2072 = prims.mul(t2071, t2070) # t2072: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 152 group_bsyms len: 1 -[Symbol name=mul] -[t48 = prims.mul(t47, t46) # t48: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 667 group_bsyms len: 1 -[Symbol name=mul] -[t783 = prims.mul(t782, t781) # t783: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1052 group_bsyms len: 1 -[Symbol name=mul] -[t1347 = prims.mul(t1346, t1345) # t1347: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1567 group_bsyms len: 1 -[Symbol name=mul] -[t2088 = prims.mul(t2087, t2086) # t2088: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 167 group_bsyms len: 1 -[Symbol name=mul] -[t64 = prims.mul(t63, t62) # t64: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 552 group_bsyms len: 1 -[Symbol name=mul] -[t622 = prims.mul(t621, t620) # t622: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1067 group_bsyms len: 1 -[Symbol name=mul] -[t1363 = prims.mul(t1362, t1361) # t1363: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1452 group_bsyms len: 1 -[Symbol name=mul] -[t1927 = prims.mul(t1926, t1925) # t1927: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 567 group_bsyms len: 1 -[Symbol name=mul] -[t638 = prims.mul(t637, t636) # t638: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 952 group_bsyms len: 1 -[Symbol name=mul] -[t1202 = prims.mul(t1201, t1200) # t1202: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1467 group_bsyms len: 1 -[Symbol name=mul] -[t1943 = prims.mul(t1942, t1941) # t1943: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 452 group_bsyms len: 1 -[Symbol name=mul] -[t477 = prims.mul(t476, t475) # t477: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 967 group_bsyms len: 1 -[Symbol name=mul] -[t1218 = prims.mul(t1217, t1216) # t1218: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1352 group_bsyms len: 1 -[Symbol name=mul] -[t1782 = prims.mul(t1781, t1780) # t1782: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 467 group_bsyms len: 1 -[Symbol name=mul] -[t493 = prims.mul(t492, t491) # t493: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 852 group_bsyms len: 1 -[Symbol name=mul] -[t1057 = prims.mul(t1056, t1055) # t1057: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1367 group_bsyms len: 1 -[Symbol name=mul] -[t1798 = prims.mul(t1797, t1796) # t1798: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 352 group_bsyms len: 1 -[Symbol name=mul] -[t332 = prims.mul(t331, t330) # t332: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 867 group_bsyms len: 1 -[Symbol name=mul] -[t1073 = prims.mul(t1072, t1071) # t1073: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1252 group_bsyms len: 1 -[Symbol name=mul] -[t1637 = prims.mul(t1636, t1635) # t1637: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 367 group_bsyms len: 1 -[Symbol name=mul] -[t348 = prims.mul(t347, t346) # t348: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 752 group_bsyms len: 1 -[Symbol name=mul] -[t912 = prims.mul(t911, t910) # t912: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1267 group_bsyms len: 1 -[Symbol name=mul] -[t1653 = prims.mul(t1652, t1651) # t1653: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1652 group_bsyms len: 1 -[Symbol name=mul] -[t2217 = prims.mul(t2216, t2215) # t2217: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 252 group_bsyms len: 1 -[Symbol name=mul] -[t187 = prims.mul(t186, t185) # t187: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 767 group_bsyms len: 1 -[Symbol name=mul] -[t928 = prims.mul(t927, t926) # t928: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1152 group_bsyms len: 1 -[Symbol name=mul] -[t1492 = prims.mul(t1491, t1490) # t1492: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 770 group_bsyms len: 1 -[Symbol name=mul] -[t931 = prims.mul(t930, t929) # t931: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1155 group_bsyms len: 1 -[Symbol name=mul] -[t1495 = prims.mul(t1494, t1493) # t1495: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1670 group_bsyms len: 1 -[Symbol name=mul] -[t2236 = prims.mul(t2235, t2234) # t2236: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 270 group_bsyms len: 1 -[Symbol name=mul] -[t206 = prims.mul(t205, t204) # t206: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 655 group_bsyms len: 1 -[Symbol name=mul] -[t770 = prims.mul(t769, t768) # t770: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1170 group_bsyms len: 1 -[Symbol name=mul] -[t1511 = prims.mul(t1510, t1509) # t1511: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1555 group_bsyms len: 1 -[Symbol name=mul] -[t2075 = prims.mul(t2074, t2073) # t2075: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 155 group_bsyms len: 1 -[Symbol name=mul] -[t51 = prims.mul(t50, t49) # t51: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 670 group_bsyms len: 1 -[Symbol name=mul] -[t786 = prims.mul(t785, t784) # t786: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1055 group_bsyms len: 1 -[Symbol name=mul] -[t1350 = prims.mul(t1349, t1348) # t1350: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1570 group_bsyms len: 1 -[Symbol name=mul] -[t2091 = prims.mul(t2090, t2089) # t2091: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 170 group_bsyms len: 1 -[Symbol name=mul] -[t67 = prims.mul(t66, t65) # t67: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 555 group_bsyms len: 1 -[Symbol name=mul] -[t625 = prims.mul(t624, t623) # t625: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1070 group_bsyms len: 1 -[Symbol name=mul] -[t1366 = prims.mul(t1365, t1364) # t1366: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1455 group_bsyms len: 1 -[Symbol name=mul] -[t1930 = prims.mul(t1929, t1928) # t1930: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 570 group_bsyms len: 1 -[Symbol name=mul] -[t641 = prims.mul(t640, t639) # t641: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 955 group_bsyms len: 1 -[Symbol name=mul] -[t1205 = prims.mul(t1204, t1203) # t1205: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1470 group_bsyms len: 1 -[Symbol name=mul] -[t1946 = prims.mul(t1945, t1944) # t1946: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 455 group_bsyms len: 1 -[Symbol name=mul] -[t480 = prims.mul(t479, t478) # t480: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 970 group_bsyms len: 1 -[Symbol name=mul] -[t1221 = prims.mul(t1220, t1219) # t1221: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1355 group_bsyms len: 1 -[Symbol name=mul] -[t1785 = prims.mul(t1784, t1783) # t1785: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 470 group_bsyms len: 1 -[Symbol name=mul] -[t496 = prims.mul(t495, t494) # t496: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 855 group_bsyms len: 1 -[Symbol name=mul] -[t1060 = prims.mul(t1059, t1058) # t1060: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1370 group_bsyms len: 1 -[Symbol name=mul] -[t1801 = prims.mul(t1800, t1799) # t1801: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 355 group_bsyms len: 1 -[Symbol name=mul] -[t335 = prims.mul(t334, t333) # t335: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 870 group_bsyms len: 1 -[Symbol name=mul] -[t1076 = prims.mul(t1075, t1074) # t1076: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1255 group_bsyms len: 1 -[Symbol name=mul] -[t1640 = prims.mul(t1639, t1638) # t1640: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 370 group_bsyms len: 1 -[Symbol name=mul] -[t351 = prims.mul(t350, t349) # t351: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 755 group_bsyms len: 1 -[Symbol name=mul] -[t915 = prims.mul(t914, t913) # t915: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1270 group_bsyms len: 1 -[Symbol name=mul] -[t1656 = prims.mul(t1655, t1654) # t1656: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1655 group_bsyms len: 1 -[Symbol name=mul] -[t2220 = prims.mul(t2219, t2218) # t2220: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 255 group_bsyms len: 1 -[Symbol name=mul] -[t190 = prims.mul(t189, t188) # t190: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 138 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 210 group_bsyms len: 1 -[Symbol name=mul] -[t114 = prims.mul(t112, t113) # t114: "cuda:0 f32[1, 512, 11008]"] - -cur node 203 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 214 group_bsyms len: 1 -[Symbol name=mul] -[t118 = prims.mul(t116, t117) # t118: "cuda:0 f32[1, 512, 11008]"] - -cur node 219 group_bsyms len: 1 -[Symbol name=add] -[t123 = prims.add(t121, t122) # t123: "cuda:0 f32[1, 512, 4096]"] - -cur node 134 group_bsyms len: 1 -[Symbol name=mul] -[t18 = prims.mul(t16, t17) # t18: "cuda:0 f32[1, 512, 4096]"] - -cur node 198 group_bsyms len: 1 -[Symbol name=mul] -[t102 = prims.mul(t100, t101) # t102: "cuda:0 f32[1, 512, 4096]"] - -cur node 238 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 283 group_bsyms len: 1 -[Symbol name=add] -[t229 = prims.add(t227, t228) # t229: "cuda:0 f32[1, 512, 4096]"] - -cur node 310 group_bsyms len: 1 -[Symbol name=mul] -[t259 = prims.mul(t257, t258) # t259: "cuda:0 f32[1, 512, 11008]"] - -cur node 303 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 314 group_bsyms len: 1 -[Symbol name=mul] -[t263 = prims.mul(t261, t262) # t263: "cuda:0 f32[1, 512, 11008]"] - -cur node 319 group_bsyms len: 1 -[Symbol name=add] -[t268 = prims.add(t266, t267) # t268: "cuda:0 f32[1, 512, 4096]"] - -cur node 234 group_bsyms len: 1 -[Symbol name=mul] -[t141 = prims.mul(t139, t140) # t141: "cuda:0 f32[1, 512, 4096]"] - -cur node 298 group_bsyms len: 1 -[Symbol name=mul] -[t247 = prims.mul(t245, t246) # t247: "cuda:0 f32[1, 512, 4096]"] - -cur node 338 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 383 group_bsyms len: 1 -[Symbol name=add] -[t374 = prims.add(t372, t373) # t374: "cuda:0 f32[1, 512, 4096]"] - -cur node 410 group_bsyms len: 1 -[Symbol name=mul] -[t404 = prims.mul(t402, t403) # t404: "cuda:0 f32[1, 512, 11008]"] - -cur node 403 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 414 group_bsyms len: 1 -[Symbol name=mul] -[t408 = prims.mul(t406, t407) # t408: "cuda:0 f32[1, 512, 11008]"] - -cur node 419 group_bsyms len: 1 -[Symbol name=add] -[t413 = prims.add(t411, t412) # t413: "cuda:0 f32[1, 512, 4096]"] - -cur node 334 group_bsyms len: 1 -[Symbol name=mul] -[t286 = prims.mul(t284, t285) # t286: "cuda:0 f32[1, 512, 4096]"] - -cur node 398 group_bsyms len: 1 -[Symbol name=mul] -[t392 = prims.mul(t390, t391) # t392: "cuda:0 f32[1, 512, 4096]"] - -cur node 438 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 483 group_bsyms len: 1 -[Symbol name=add] -[t519 = prims.add(t517, t518) # t519: "cuda:0 f32[1, 512, 4096]"] - -cur node 510 group_bsyms len: 1 -[Symbol name=mul] -[t549 = prims.mul(t547, t548) # t549: "cuda:0 f32[1, 512, 11008]"] - -cur node 503 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 514 group_bsyms len: 1 -[Symbol name=mul] -[t553 = prims.mul(t551, t552) # t553: "cuda:0 f32[1, 512, 11008]"] - -cur node 519 group_bsyms len: 1 -[Symbol name=add] -[t558 = prims.add(t556, t557) # t558: "cuda:0 f32[1, 512, 4096]"] - -cur node 434 group_bsyms len: 1 -[Symbol name=mul] -[t431 = prims.mul(t429, t430) # t431: "cuda:0 f32[1, 512, 4096]"] - -cur node 498 group_bsyms len: 1 -[Symbol name=mul] -[t537 = prims.mul(t535, t536) # t537: "cuda:0 f32[1, 512, 4096]"] - -cur node 538 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 583 group_bsyms len: 1 -[Symbol name=add] -[t664 = prims.add(t662, t663) # t664: "cuda:0 f32[1, 512, 4096]"] - -cur node 610 group_bsyms len: 1 -[Symbol name=mul] -[t694 = prims.mul(t692, t693) # t694: "cuda:0 f32[1, 512, 11008]"] - -cur node 603 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 614 group_bsyms len: 1 -[Symbol name=mul] -[t698 = prims.mul(t696, t697) # t698: "cuda:0 f32[1, 512, 11008]"] - -cur node 619 group_bsyms len: 1 -[Symbol name=add] -[t703 = prims.add(t701, t702) # t703: "cuda:0 f32[1, 512, 4096]"] - -cur node 534 group_bsyms len: 1 -[Symbol name=mul] -[t576 = prims.mul(t574, t575) # t576: "cuda:0 f32[1, 512, 4096]"] - -cur node 598 group_bsyms len: 1 -[Symbol name=mul] -[t682 = prims.mul(t680, t681) # t682: "cuda:0 f32[1, 512, 4096]"] - -cur node 638 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 683 group_bsyms len: 1 -[Symbol name=add] -[t809 = prims.add(t807, t808) # t809: "cuda:0 f32[1, 512, 4096]"] - -cur node 710 group_bsyms len: 1 -[Symbol name=mul] -[t839 = prims.mul(t837, t838) # t839: "cuda:0 f32[1, 512, 11008]"] - -cur node 703 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 714 group_bsyms len: 1 -[Symbol name=mul] -[t843 = prims.mul(t841, t842) # t843: "cuda:0 f32[1, 512, 11008]"] - -cur node 719 group_bsyms len: 1 -[Symbol name=add] -[t848 = prims.add(t846, t847) # t848: "cuda:0 f32[1, 512, 4096]"] - -cur node 634 group_bsyms len: 1 -[Symbol name=mul] -[t721 = prims.mul(t719, t720) # t721: "cuda:0 f32[1, 512, 4096]"] - -cur node 698 group_bsyms len: 1 -[Symbol name=mul] -[t827 = prims.mul(t825, t826) # t827: "cuda:0 f32[1, 512, 4096]"] - -cur node 738 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 783 group_bsyms len: 1 -[Symbol name=add] -[t954 = prims.add(t952, t953) # t954: "cuda:0 f32[1, 512, 4096]"] - -cur node 810 group_bsyms len: 1 -[Symbol name=mul] -[t984 = prims.mul(t982, t983) # t984: "cuda:0 f32[1, 512, 11008]"] - -cur node 803 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 814 group_bsyms len: 1 -[Symbol name=mul] -[t988 = prims.mul(t986, t987) # t988: "cuda:0 f32[1, 512, 11008]"] - -cur node 819 group_bsyms len: 1 -[Symbol name=add] -[t993 = prims.add(t991, t992) # t993: "cuda:0 f32[1, 512, 4096]"] - -cur node 734 group_bsyms len: 1 -[Symbol name=mul] -[t866 = prims.mul(t864, t865) # t866: "cuda:0 f32[1, 512, 4096]"] - -cur node 798 group_bsyms len: 1 -[Symbol name=mul] -[t972 = prims.mul(t970, t971) # t972: "cuda:0 f32[1, 512, 4096]"] - -cur node 838 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 883 group_bsyms len: 1 -[Symbol name=add] -[t1099 = prims.add(t1097, t1098) # t1099: "cuda:0 f32[1, 512, 4096]"] - -cur node 910 group_bsyms len: 1 -[Symbol name=mul] -[t1129 = prims.mul(t1127, t1128) # t1129: "cuda:0 f32[1, 512, 11008]"] - -cur node 903 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 914 group_bsyms len: 1 -[Symbol name=mul] -[t1133 = prims.mul(t1131, t1132) # t1133: "cuda:0 f32[1, 512, 11008]"] - -cur node 919 group_bsyms len: 1 -[Symbol name=add] -[t1138 = prims.add(t1136, t1137) # t1138: "cuda:0 f32[1, 512, 4096]"] - -cur node 834 group_bsyms len: 1 -[Symbol name=mul] -[t1011 = prims.mul(t1009, t1010) # t1011: "cuda:0 f32[1, 512, 4096]"] - -cur node 898 group_bsyms len: 1 -[Symbol name=mul] -[t1117 = prims.mul(t1115, t1116) # t1117: "cuda:0 f32[1, 512, 4096]"] - -cur node 938 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 983 group_bsyms len: 1 -[Symbol name=add] -[t1244 = prims.add(t1242, t1243) # t1244: "cuda:0 f32[1, 512, 4096]"] - -cur node 1010 group_bsyms len: 1 -[Symbol name=mul] -[t1274 = prims.mul(t1272, t1273) # t1274: "cuda:0 f32[1, 512, 11008]"] - -cur node 1003 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1014 group_bsyms len: 1 -[Symbol name=mul] -[t1278 = prims.mul(t1276, t1277) # t1278: "cuda:0 f32[1, 512, 11008]"] - -cur node 1019 group_bsyms len: 1 -[Symbol name=add] -[t1283 = prims.add(t1281, t1282) # t1283: "cuda:0 f32[1, 512, 4096]"] - -cur node 934 group_bsyms len: 1 -[Symbol name=mul] -[t1156 = prims.mul(t1154, t1155) # t1156: "cuda:0 f32[1, 512, 4096]"] - -cur node 998 group_bsyms len: 1 -[Symbol name=mul] -[t1262 = prims.mul(t1260, t1261) # t1262: "cuda:0 f32[1, 512, 4096]"] - -cur node 1038 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1083 group_bsyms len: 1 -[Symbol name=add] -[t1389 = prims.add(t1387, t1388) # t1389: "cuda:0 f32[1, 512, 4096]"] - -cur node 1110 group_bsyms len: 1 -[Symbol name=mul] -[t1419 = prims.mul(t1417, t1418) # t1419: "cuda:0 f32[1, 512, 11008]"] - -cur node 1103 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1114 group_bsyms len: 1 -[Symbol name=mul] -[t1423 = prims.mul(t1421, t1422) # t1423: "cuda:0 f32[1, 512, 11008]"] - -cur node 1119 group_bsyms len: 1 -[Symbol name=add] -[t1428 = prims.add(t1426, t1427) # t1428: "cuda:0 f32[1, 512, 4096]"] - -cur node 1034 group_bsyms len: 1 -[Symbol name=mul] -[t1301 = prims.mul(t1299, t1300) # t1301: "cuda:0 f32[1, 512, 4096]"] - -cur node 1098 group_bsyms len: 1 -[Symbol name=mul] -[t1407 = prims.mul(t1405, t1406) # t1407: "cuda:0 f32[1, 512, 4096]"] - -cur node 1138 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1183 group_bsyms len: 1 -[Symbol name=add] -[t1534 = prims.add(t1532, t1533) # t1534: "cuda:0 f32[1, 512, 4096]"] - -cur node 1210 group_bsyms len: 1 -[Symbol name=mul] -[t1564 = prims.mul(t1562, t1563) # t1564: "cuda:0 f32[1, 512, 11008]"] - -cur node 1203 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1214 group_bsyms len: 1 -[Symbol name=mul] -[t1568 = prims.mul(t1566, t1567) # t1568: "cuda:0 f32[1, 512, 11008]"] - -cur node 1219 group_bsyms len: 1 -[Symbol name=add] -[t1573 = prims.add(t1571, t1572) # t1573: "cuda:0 f32[1, 512, 4096]"] - -cur node 1134 group_bsyms len: 1 -[Symbol name=mul] -[t1446 = prims.mul(t1444, t1445) # t1446: "cuda:0 f32[1, 512, 4096]"] - -cur node 1198 group_bsyms len: 1 -[Symbol name=mul] -[t1552 = prims.mul(t1550, t1551) # t1552: "cuda:0 f32[1, 512, 4096]"] - -cur node 1238 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1283 group_bsyms len: 1 -[Symbol name=add] -[t1679 = prims.add(t1677, t1678) # t1679: "cuda:0 f32[1, 512, 4096]"] - -cur node 1310 group_bsyms len: 1 -[Symbol name=mul] -[t1709 = prims.mul(t1707, t1708) # t1709: "cuda:0 f32[1, 512, 11008]"] - -cur node 1303 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1314 group_bsyms len: 1 -[Symbol name=mul] -[t1713 = prims.mul(t1711, t1712) # t1713: "cuda:0 f32[1, 512, 11008]"] - -cur node 1319 group_bsyms len: 1 -[Symbol name=add] -[t1718 = prims.add(t1716, t1717) # t1718: "cuda:0 f32[1, 512, 4096]"] - -cur node 1234 group_bsyms len: 1 -[Symbol name=mul] -[t1591 = prims.mul(t1589, t1590) # t1591: "cuda:0 f32[1, 512, 4096]"] - -cur node 1298 group_bsyms len: 1 -[Symbol name=mul] -[t1697 = prims.mul(t1695, t1696) # t1697: "cuda:0 f32[1, 512, 4096]"] - -cur node 1338 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1383 group_bsyms len: 1 -[Symbol name=add] -[t1824 = prims.add(t1822, t1823) # t1824: "cuda:0 f32[1, 512, 4096]"] - -cur node 1410 group_bsyms len: 1 -[Symbol name=mul] -[t1854 = prims.mul(t1852, t1853) # t1854: "cuda:0 f32[1, 512, 11008]"] - -cur node 1403 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1414 group_bsyms len: 1 -[Symbol name=mul] -[t1858 = prims.mul(t1856, t1857) # t1858: "cuda:0 f32[1, 512, 11008]"] - -cur node 1419 group_bsyms len: 1 -[Symbol name=add] -[t1863 = prims.add(t1861, t1862) # t1863: "cuda:0 f32[1, 512, 4096]"] - -cur node 1334 group_bsyms len: 1 -[Symbol name=mul] -[t1736 = prims.mul(t1734, t1735) # t1736: "cuda:0 f32[1, 512, 4096]"] - -cur node 1398 group_bsyms len: 1 -[Symbol name=mul] -[t1842 = prims.mul(t1840, t1841) # t1842: "cuda:0 f32[1, 512, 4096]"] - -cur node 1438 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1483 group_bsyms len: 1 -[Symbol name=add] -[t1969 = prims.add(t1967, t1968) # t1969: "cuda:0 f32[1, 512, 4096]"] - -cur node 1510 group_bsyms len: 1 -[Symbol name=mul] -[t1999 = prims.mul(t1997, t1998) # t1999: "cuda:0 f32[1, 512, 11008]"] - -cur node 1503 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1514 group_bsyms len: 1 -[Symbol name=mul] -[t2003 = prims.mul(t2001, t2002) # t2003: "cuda:0 f32[1, 512, 11008]"] - -cur node 1519 group_bsyms len: 1 -[Symbol name=add] -[t2008 = prims.add(t2006, t2007) # t2008: "cuda:0 f32[1, 512, 4096]"] - -cur node 1434 group_bsyms len: 1 -[Symbol name=mul] -[t1881 = prims.mul(t1879, t1880) # t1881: "cuda:0 f32[1, 512, 4096]"] - -cur node 1498 group_bsyms len: 1 -[Symbol name=mul] -[t1987 = prims.mul(t1985, t1986) # t1987: "cuda:0 f32[1, 512, 4096]"] - -cur node 1538 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1583 group_bsyms len: 1 -[Symbol name=add] -[t2114 = prims.add(t2112, t2113) # t2114: "cuda:0 f32[1, 512, 4096]"] - -cur node 1610 group_bsyms len: 1 -[Symbol name=mul] -[t2144 = prims.mul(t2142, t2143) # t2144: "cuda:0 f32[1, 512, 11008]"] - -cur node 1603 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1614 group_bsyms len: 1 -[Symbol name=mul] -[t2148 = prims.mul(t2146, t2147) # t2148: "cuda:0 f32[1, 512, 11008]"] - -cur node 1619 group_bsyms len: 1 -[Symbol name=add] -[t2153 = prims.add(t2151, t2152) # t2153: "cuda:0 f32[1, 512, 4096]"] - -cur node 1534 group_bsyms len: 1 -[Symbol name=mul] -[t2026 = prims.mul(t2024, t2025) # t2026: "cuda:0 f32[1, 512, 4096]"] - -cur node 1598 group_bsyms len: 1 -[Symbol name=mul] -[t2132 = prims.mul(t2130, t2131) # t2132: "cuda:0 f32[1, 512, 4096]"] - -cur node 1638 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1683 group_bsyms len: 1 -[Symbol name=add] -[t2259 = prims.add(t2257, t2258) # t2259: "cuda:0 f32[1, 512, 4096]"] - -cur node 1710 group_bsyms len: 1 -[Symbol name=mul] -[t2289 = prims.mul(t2287, t2288) # t2289: "cuda:0 f32[1, 512, 11008]"] - -cur node 1703 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1714 group_bsyms len: 1 -[Symbol name=mul] -[t2293 = prims.mul(t2291, t2292) # t2293: "cuda:0 f32[1, 512, 11008]"] - -cur node 1719 group_bsyms len: 1 -[Symbol name=add] -[t2298 = prims.add(t2296, t2297) # t2298: "cuda:0 f32[1, 512, 4096]"] - -cur node 1634 group_bsyms len: 1 -[Symbol name=mul] -[t2171 = prims.mul(t2169, t2170) # t2171: "cuda:0 f32[1, 512, 4096]"] - -cur node 1698 group_bsyms len: 1 -[Symbol name=mul] -[t2277 = prims.mul(t2275, t2276) # t2277: "cuda:0 f32[1, 512, 4096]"] - -cur node 1734 group_bsyms len: 1 -[Symbol name=mul] -[t2316 = prims.mul(t2314, t2315) # t2316: "cuda:0 f32[1, 512, 4096]"] - -cur node 130 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 123 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 184 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1671 group_bsyms len: 1 -[Symbol name=add] -[t2237 = prims.add(t2233, t2236) # t2237: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 271 group_bsyms len: 1 -[Symbol name=add] -[t207 = prims.add(t203, t206) # t207: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 656 group_bsyms len: 1 -[Symbol name=add] -[t771 = prims.add(t767, t770) # t771: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1171 group_bsyms len: 1 -[Symbol name=add] -[t1512 = prims.add(t1508, t1511) # t1512: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1556 group_bsyms len: 1 -[Symbol name=add] -[t2076 = prims.add(t2072, t2075) # t2076: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 156 group_bsyms len: 1 -[Symbol name=add] -[t52 = prims.add(t48, t51) # t52: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 671 group_bsyms len: 1 -[Symbol name=add] -[t787 = prims.add(t783, t786) # t787: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1056 group_bsyms len: 1 -[Symbol name=add] -[t1351 = prims.add(t1347, t1350) # t1351: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1571 group_bsyms len: 1 -[Symbol name=add] -[t2092 = prims.add(t2088, t2091) # t2092: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 171 group_bsyms len: 1 -[Symbol name=add] -[t68 = prims.add(t64, t67) # t68: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 556 group_bsyms len: 1 -[Symbol name=add] -[t626 = prims.add(t622, t625) # t626: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1071 group_bsyms len: 1 -[Symbol name=add] -[t1367 = prims.add(t1363, t1366) # t1367: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1456 group_bsyms len: 1 -[Symbol name=add] -[t1931 = prims.add(t1927, t1930) # t1931: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 571 group_bsyms len: 1 -[Symbol name=add] -[t642 = prims.add(t638, t641) # t642: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 956 group_bsyms len: 1 -[Symbol name=add] -[t1206 = prims.add(t1202, t1205) # t1206: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1471 group_bsyms len: 1 -[Symbol name=add] -[t1947 = prims.add(t1943, t1946) # t1947: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 456 group_bsyms len: 1 -[Symbol name=add] -[t481 = prims.add(t477, t480) # t481: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 971 group_bsyms len: 1 -[Symbol name=add] -[t1222 = prims.add(t1218, t1221) # t1222: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1356 group_bsyms len: 1 -[Symbol name=add] -[t1786 = prims.add(t1782, t1785) # t1786: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 471 group_bsyms len: 1 -[Symbol name=add] -[t497 = prims.add(t493, t496) # t497: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 856 group_bsyms len: 1 -[Symbol name=add] -[t1061 = prims.add(t1057, t1060) # t1061: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1371 group_bsyms len: 1 -[Symbol name=add] -[t1802 = prims.add(t1798, t1801) # t1802: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 356 group_bsyms len: 1 -[Symbol name=add] -[t336 = prims.add(t332, t335) # t336: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 871 group_bsyms len: 1 -[Symbol name=add] -[t1077 = prims.add(t1073, t1076) # t1077: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1256 group_bsyms len: 1 -[Symbol name=add] -[t1641 = prims.add(t1637, t1640) # t1641: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 371 group_bsyms len: 1 -[Symbol name=add] -[t352 = prims.add(t348, t351) # t352: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 756 group_bsyms len: 1 -[Symbol name=add] -[t916 = prims.add(t912, t915) # t916: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1271 group_bsyms len: 1 -[Symbol name=add] -[t1657 = prims.add(t1653, t1656) # t1657: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1656 group_bsyms len: 1 -[Symbol name=add] -[t2221 = prims.add(t2217, t2220) # t2221: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 256 group_bsyms len: 1 -[Symbol name=add] -[t191 = prims.add(t187, t190) # t191: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 771 group_bsyms len: 1 -[Symbol name=add] -[t932 = prims.add(t928, t931) # t932: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 1156 group_bsyms len: 1 -[Symbol name=add] -[t1496 = prims.add(t1492, t1495) # t1496: "cuda:0 f32[1, 32, 512, 128]"] - -cur node 139 group_bsyms len: 1 -[Symbol name=split] -[t23 = prims.slice_prim(t22, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t23: "cuda:0 bf16[1, 32, 1, 512, 128]", t24 = prims.slice_prim(t22, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t24: "cuda:0 bf16[1, 32, 1, 512, 128]", t25 = prims.slice_prim(t22, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t25: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 211 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 204 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 215 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 220 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 135 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 199 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 239 group_bsyms len: 1 -[Symbol name=split] -[t156 = prims.slice_prim(t155, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t156: "cuda:0 bf16[1, 32, 1, 512, 128]", t157 = prims.slice_prim(t155, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t157: "cuda:0 bf16[1, 32, 1, 512, 128]", t158 = prims.slice_prim(t155, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t158: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 284 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 311 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 304 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 315 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 320 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 235 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 299 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 339 group_bsyms len: 1 -[Symbol name=split] -[t301 = prims.slice_prim(t300, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t301: "cuda:0 bf16[1, 32, 1, 512, 128]", t302 = prims.slice_prim(t300, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t302: "cuda:0 bf16[1, 32, 1, 512, 128]", t303 = prims.slice_prim(t300, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t303: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 384 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 411 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 404 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 415 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 420 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 335 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 399 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 439 group_bsyms len: 1 -[Symbol name=split] -[t446 = prims.slice_prim(t445, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t446: "cuda:0 bf16[1, 32, 1, 512, 128]", t447 = prims.slice_prim(t445, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t447: "cuda:0 bf16[1, 32, 1, 512, 128]", t448 = prims.slice_prim(t445, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t448: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 484 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 511 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 504 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 515 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 520 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 435 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 499 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 539 group_bsyms len: 1 -[Symbol name=split] -[t591 = prims.slice_prim(t590, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t591: "cuda:0 bf16[1, 32, 1, 512, 128]", t592 = prims.slice_prim(t590, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t592: "cuda:0 bf16[1, 32, 1, 512, 128]", t593 = prims.slice_prim(t590, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t593: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 584 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 611 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 604 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 615 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 620 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 535 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 599 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 639 group_bsyms len: 1 -[Symbol name=split] -[t736 = prims.slice_prim(t735, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t736: "cuda:0 bf16[1, 32, 1, 512, 128]", t737 = prims.slice_prim(t735, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t737: "cuda:0 bf16[1, 32, 1, 512, 128]", t738 = prims.slice_prim(t735, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t738: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 684 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 711 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 704 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 715 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 720 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 635 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 699 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 739 group_bsyms len: 1 -[Symbol name=split] -[t881 = prims.slice_prim(t880, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t881: "cuda:0 bf16[1, 32, 1, 512, 128]", t882 = prims.slice_prim(t880, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t882: "cuda:0 bf16[1, 32, 1, 512, 128]", t883 = prims.slice_prim(t880, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t883: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 784 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 811 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 804 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 815 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 820 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 735 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 799 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 839 group_bsyms len: 1 -[Symbol name=split] -[t1026 = prims.slice_prim(t1025, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1026: "cuda:0 bf16[1, 32, 1, 512, 128]", t1027 = prims.slice_prim(t1025, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1027: "cuda:0 bf16[1, 32, 1, 512, 128]", t1028 = prims.slice_prim(t1025, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1028: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 884 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 911 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 904 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 915 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 920 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 835 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 899 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 939 group_bsyms len: 1 -[Symbol name=split] -[t1171 = prims.slice_prim(t1170, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1171: "cuda:0 bf16[1, 32, 1, 512, 128]", t1172 = prims.slice_prim(t1170, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1172: "cuda:0 bf16[1, 32, 1, 512, 128]", t1173 = prims.slice_prim(t1170, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1173: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 984 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1011 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1004 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1015 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1020 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 935 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 999 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1039 group_bsyms len: 1 -[Symbol name=split] -[t1316 = prims.slice_prim(t1315, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1316: "cuda:0 bf16[1, 32, 1, 512, 128]", t1317 = prims.slice_prim(t1315, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1317: "cuda:0 bf16[1, 32, 1, 512, 128]", t1318 = prims.slice_prim(t1315, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1318: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 1084 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1111 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1104 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1115 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1120 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1035 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1099 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1139 group_bsyms len: 1 -[Symbol name=split] -[t1461 = prims.slice_prim(t1460, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1461: "cuda:0 bf16[1, 32, 1, 512, 128]", t1462 = prims.slice_prim(t1460, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1462: "cuda:0 bf16[1, 32, 1, 512, 128]", t1463 = prims.slice_prim(t1460, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1463: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 1184 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1211 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1204 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1215 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1220 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1135 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1199 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1239 group_bsyms len: 1 -[Symbol name=split] -[t1606 = prims.slice_prim(t1605, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1606: "cuda:0 bf16[1, 32, 1, 512, 128]", t1607 = prims.slice_prim(t1605, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1607: "cuda:0 bf16[1, 32, 1, 512, 128]", t1608 = prims.slice_prim(t1605, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1608: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 1284 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1311 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1304 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1315 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1320 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1235 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1299 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1339 group_bsyms len: 1 -[Symbol name=split] -[t1751 = prims.slice_prim(t1750, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1751: "cuda:0 bf16[1, 32, 1, 512, 128]", t1752 = prims.slice_prim(t1750, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1752: "cuda:0 bf16[1, 32, 1, 512, 128]", t1753 = prims.slice_prim(t1750, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1753: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 1384 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1411 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1404 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1415 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1420 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1335 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1399 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1439 group_bsyms len: 1 -[Symbol name=split] -[t1896 = prims.slice_prim(t1895, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t1896: "cuda:0 bf16[1, 32, 1, 512, 128]", t1897 = prims.slice_prim(t1895, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t1897: "cuda:0 bf16[1, 32, 1, 512, 128]", t1898 = prims.slice_prim(t1895, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t1898: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 1484 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1511 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1504 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1515 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1520 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1435 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1499 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1539 group_bsyms len: 1 -[Symbol name=split] -[t2041 = prims.slice_prim(t2040, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2041: "cuda:0 bf16[1, 32, 1, 512, 128]", t2042 = prims.slice_prim(t2040, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2042: "cuda:0 bf16[1, 32, 1, 512, 128]", t2043 = prims.slice_prim(t2040, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2043: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 1584 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1611 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1604 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1615 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1620 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1535 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1599 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1639 group_bsyms len: 1 -[Symbol name=split] -[t2186 = prims.slice_prim(t2185, [0, 0, 0, 0, 0], [1, 32, 1, 512, 128], [1, 1, 1, 1, 1]) # t2186: "cuda:0 bf16[1, 32, 1, 512, 128]", t2187 = prims.slice_prim(t2185, [0, 0, 1, 0, 0], [1, 32, 2, 512, 128], [1, 1, 1, 1, 1]) # t2187: "cuda:0 bf16[1, 32, 1, 512, 128]", t2188 = prims.slice_prim(t2185, [0, 0, 2, 0, 0], [1, 32, 3, 512, 128], [1, 1, 1, 1, 1]) # t2188: "cuda:0 bf16[1, 32, 1, 512, 128]"] - -cur node 1684 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1711 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1704 group_bsyms len: 1 -[Symbol name=exp] -[] - -cur node 1715 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1720 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1635 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1699 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1735 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 132 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 124 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 185 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 218 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1672 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 272 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 657 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1172 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1557 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 157 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 672 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1057 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1572 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 172 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 557 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1072 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1457 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 572 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 957 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1472 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 457 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 972 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1357 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 472 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 857 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1372 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 357 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 872 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1257 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 372 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 757 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1272 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1657 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 257 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 772 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1157 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 140 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 141 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 142 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 212 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 205 group_bsyms len: 1 -[Symbol name=add] -[t109 = prims.add(1.0, t108) # t109: "cuda:0 f32[1, 512, 11008]"] - -cur node 282 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 221 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 240 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 241 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 242 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 285 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 318 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 312 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 305 group_bsyms len: 1 -[Symbol name=add] -[t254 = prims.add(1.0, t253) # t254: "cuda:0 f32[1, 512, 11008]"] - -cur node 321 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 382 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 340 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 341 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 342 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 385 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 418 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 412 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 405 group_bsyms len: 1 -[Symbol name=add] -[t399 = prims.add(1.0, t398) # t399: "cuda:0 f32[1, 512, 11008]"] - -cur node 482 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 421 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 440 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 441 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 442 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 485 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 518 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 512 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 505 group_bsyms len: 1 -[Symbol name=add] -[t544 = prims.add(1.0, t543) # t544: "cuda:0 f32[1, 512, 11008]"] - -cur node 521 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 582 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 540 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 541 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 542 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 585 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 618 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 612 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 605 group_bsyms len: 1 -[Symbol name=add] -[t689 = prims.add(1.0, t688) # t689: "cuda:0 f32[1, 512, 11008]"] - -cur node 682 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 621 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 640 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 641 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 642 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 685 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 718 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 712 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 705 group_bsyms len: 1 -[Symbol name=add] -[t834 = prims.add(1.0, t833) # t834: "cuda:0 f32[1, 512, 11008]"] - -cur node 721 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 782 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 740 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 741 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 742 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 785 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 818 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 812 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 805 group_bsyms len: 1 -[Symbol name=add] -[t979 = prims.add(1.0, t978) # t979: "cuda:0 f32[1, 512, 11008]"] - -cur node 882 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 821 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 840 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 841 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 842 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 885 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 918 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 912 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 905 group_bsyms len: 1 -[Symbol name=add] -[t1124 = prims.add(1.0, t1123) # t1124: "cuda:0 f32[1, 512, 11008]"] - -cur node 921 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 982 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 940 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 941 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 942 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 985 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1018 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1012 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1005 group_bsyms len: 1 -[Symbol name=add] -[t1269 = prims.add(1.0, t1268) # t1269: "cuda:0 f32[1, 512, 11008]"] - -cur node 1082 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1021 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1040 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1041 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1042 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1085 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1118 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1112 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1105 group_bsyms len: 1 -[Symbol name=add] -[t1414 = prims.add(1.0, t1413) # t1414: "cuda:0 f32[1, 512, 11008]"] - -cur node 1121 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1182 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1140 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1141 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1142 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1185 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1218 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1212 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1205 group_bsyms len: 1 -[Symbol name=add] -[t1559 = prims.add(1.0, t1558) # t1559: "cuda:0 f32[1, 512, 11008]"] - -cur node 1282 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1221 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1240 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1241 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1242 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1285 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1318 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1312 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1305 group_bsyms len: 1 -[Symbol name=add] -[t1704 = prims.add(1.0, t1703) # t1704: "cuda:0 f32[1, 512, 11008]"] - -cur node 1321 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1382 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1340 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1341 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1342 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1385 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1418 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1412 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1405 group_bsyms len: 1 -[Symbol name=add] -[t1849 = prims.add(1.0, t1848) # t1849: "cuda:0 f32[1, 512, 11008]"] - -cur node 1482 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1421 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1440 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1441 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1442 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1485 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1518 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1512 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1505 group_bsyms len: 1 -[Symbol name=add] -[t1994 = prims.add(1.0, t1993) # t1994: "cuda:0 f32[1, 512, 11008]"] - -cur node 1521 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1582 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1540 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1541 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1542 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1585 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1618 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1612 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1605 group_bsyms len: 1 -[Symbol name=add] -[t2139 = prims.add(1.0, t2138) # t2139: "cuda:0 f32[1, 512, 11008]"] - -cur node 1682 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1621 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1640 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1641 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1642 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1685 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1718 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1712 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1705 group_bsyms len: 1 -[Symbol name=add] -[t2284 = prims.add(1.0, t2283) # t2284: "cuda:0 f32[1, 512, 11008]"] - -cur node 1721 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 125 group_bsyms len: 1 -[Symbol name=true_divide] -[t9 = prims.div(t8, 4096.0) # t9: "cuda:0 f32[1, 512, 1]"] - -cur node 193 group_bsyms len: 1 -[Symbol name=mul] -[t97 = prims.mul(t86, t96) # t97: "cuda:0 f32[1, 512, 4096]"] - -cur node 186 group_bsyms len: 1 -[Symbol name=mul] -[t87 = prims.mul(t86, t86) # t87: "cuda:0 f32[1, 512, 4096]"] - -cur node 1676 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 276 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 674 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1176 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1574 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 174 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 676 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1074 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1576 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 176 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 574 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1076 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1474 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 576 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 974 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1476 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 474 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 976 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1374 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 476 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 874 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1376 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 374 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 876 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1274 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 376 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 774 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1276 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1674 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 274 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 776 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1174 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 173 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 143 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 158 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 175 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 177 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 206 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 229 group_bsyms len: 1 -[Symbol name=mul] -[t136 = prims.mul(t125, t135) # t136: "cuda:0 f32[1, 512, 4096]"] - -cur node 222 group_bsyms len: 1 -[Symbol name=mul] -[t126 = prims.mul(t125, t125) # t126: "cuda:0 f32[1, 512, 4096]"] - -cur node 273 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 243 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 258 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 275 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 277 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 293 group_bsyms len: 1 -[Symbol name=mul] -[t242 = prims.mul(t231, t241) # t242: "cuda:0 f32[1, 512, 4096]"] - -cur node 286 group_bsyms len: 1 -[Symbol name=mul] -[t232 = prims.mul(t231, t231) # t232: "cuda:0 f32[1, 512, 4096]"] - -cur node 306 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 329 group_bsyms len: 1 -[Symbol name=mul] -[t281 = prims.mul(t270, t280) # t281: "cuda:0 f32[1, 512, 4096]"] - -cur node 322 group_bsyms len: 1 -[Symbol name=mul] -[t271 = prims.mul(t270, t270) # t271: "cuda:0 f32[1, 512, 4096]"] - -cur node 373 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 343 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 358 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 375 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 377 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 393 group_bsyms len: 1 -[Symbol name=mul] -[t387 = prims.mul(t376, t386) # t387: "cuda:0 f32[1, 512, 4096]"] - -cur node 386 group_bsyms len: 1 -[Symbol name=mul] -[t377 = prims.mul(t376, t376) # t377: "cuda:0 f32[1, 512, 4096]"] - -cur node 406 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 429 group_bsyms len: 1 -[Symbol name=mul] -[t426 = prims.mul(t415, t425) # t426: "cuda:0 f32[1, 512, 4096]"] - -cur node 422 group_bsyms len: 1 -[Symbol name=mul] -[t416 = prims.mul(t415, t415) # t416: "cuda:0 f32[1, 512, 4096]"] - -cur node 473 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 443 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 458 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 475 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 477 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 493 group_bsyms len: 1 -[Symbol name=mul] -[t532 = prims.mul(t521, t531) # t532: "cuda:0 f32[1, 512, 4096]"] - -cur node 486 group_bsyms len: 1 -[Symbol name=mul] -[t522 = prims.mul(t521, t521) # t522: "cuda:0 f32[1, 512, 4096]"] - -cur node 506 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 529 group_bsyms len: 1 -[Symbol name=mul] -[t571 = prims.mul(t560, t570) # t571: "cuda:0 f32[1, 512, 4096]"] - -cur node 522 group_bsyms len: 1 -[Symbol name=mul] -[t561 = prims.mul(t560, t560) # t561: "cuda:0 f32[1, 512, 4096]"] - -cur node 573 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 543 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 558 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 575 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 577 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 593 group_bsyms len: 1 -[Symbol name=mul] -[t677 = prims.mul(t666, t676) # t677: "cuda:0 f32[1, 512, 4096]"] - -cur node 586 group_bsyms len: 1 -[Symbol name=mul] -[t667 = prims.mul(t666, t666) # t667: "cuda:0 f32[1, 512, 4096]"] - -cur node 606 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 629 group_bsyms len: 1 -[Symbol name=mul] -[t716 = prims.mul(t705, t715) # t716: "cuda:0 f32[1, 512, 4096]"] - -cur node 622 group_bsyms len: 1 -[Symbol name=mul] -[t706 = prims.mul(t705, t705) # t706: "cuda:0 f32[1, 512, 4096]"] - -cur node 673 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 643 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 658 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 675 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 677 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 693 group_bsyms len: 1 -[Symbol name=mul] -[t822 = prims.mul(t811, t821) # t822: "cuda:0 f32[1, 512, 4096]"] - -cur node 686 group_bsyms len: 1 -[Symbol name=mul] -[t812 = prims.mul(t811, t811) # t812: "cuda:0 f32[1, 512, 4096]"] - -cur node 706 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 729 group_bsyms len: 1 -[Symbol name=mul] -[t861 = prims.mul(t850, t860) # t861: "cuda:0 f32[1, 512, 4096]"] - -cur node 722 group_bsyms len: 1 -[Symbol name=mul] -[t851 = prims.mul(t850, t850) # t851: "cuda:0 f32[1, 512, 4096]"] - -cur node 773 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 743 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 758 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 775 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 777 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 793 group_bsyms len: 1 -[Symbol name=mul] -[t967 = prims.mul(t956, t966) # t967: "cuda:0 f32[1, 512, 4096]"] - -cur node 786 group_bsyms len: 1 -[Symbol name=mul] -[t957 = prims.mul(t956, t956) # t957: "cuda:0 f32[1, 512, 4096]"] - -cur node 806 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 829 group_bsyms len: 1 -[Symbol name=mul] -[t1006 = prims.mul(t995, t1005) # t1006: "cuda:0 f32[1, 512, 4096]"] - -cur node 822 group_bsyms len: 1 -[Symbol name=mul] -[t996 = prims.mul(t995, t995) # t996: "cuda:0 f32[1, 512, 4096]"] - -cur node 873 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 843 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 858 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 875 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 877 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 893 group_bsyms len: 1 -[Symbol name=mul] -[t1112 = prims.mul(t1101, t1111) # t1112: "cuda:0 f32[1, 512, 4096]"] - -cur node 886 group_bsyms len: 1 -[Symbol name=mul] -[t1102 = prims.mul(t1101, t1101) # t1102: "cuda:0 f32[1, 512, 4096]"] - -cur node 906 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 929 group_bsyms len: 1 -[Symbol name=mul] -[t1151 = prims.mul(t1140, t1150) # t1151: "cuda:0 f32[1, 512, 4096]"] - -cur node 922 group_bsyms len: 1 -[Symbol name=mul] -[t1141 = prims.mul(t1140, t1140) # t1141: "cuda:0 f32[1, 512, 4096]"] - -cur node 973 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 943 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 958 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 975 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 977 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 993 group_bsyms len: 1 -[Symbol name=mul] -[t1257 = prims.mul(t1246, t1256) # t1257: "cuda:0 f32[1, 512, 4096]"] - -cur node 986 group_bsyms len: 1 -[Symbol name=mul] -[t1247 = prims.mul(t1246, t1246) # t1247: "cuda:0 f32[1, 512, 4096]"] - -cur node 1006 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1029 group_bsyms len: 1 -[Symbol name=mul] -[t1296 = prims.mul(t1285, t1295) # t1296: "cuda:0 f32[1, 512, 4096]"] - -cur node 1022 group_bsyms len: 1 -[Symbol name=mul] -[t1286 = prims.mul(t1285, t1285) # t1286: "cuda:0 f32[1, 512, 4096]"] - -cur node 1073 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1043 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1058 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1075 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1077 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 1093 group_bsyms len: 1 -[Symbol name=mul] -[t1402 = prims.mul(t1391, t1401) # t1402: "cuda:0 f32[1, 512, 4096]"] - -cur node 1086 group_bsyms len: 1 -[Symbol name=mul] -[t1392 = prims.mul(t1391, t1391) # t1392: "cuda:0 f32[1, 512, 4096]"] - -cur node 1106 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1129 group_bsyms len: 1 -[Symbol name=mul] -[t1441 = prims.mul(t1430, t1440) # t1441: "cuda:0 f32[1, 512, 4096]"] - -cur node 1122 group_bsyms len: 1 -[Symbol name=mul] -[t1431 = prims.mul(t1430, t1430) # t1431: "cuda:0 f32[1, 512, 4096]"] - -cur node 1173 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1143 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1158 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1175 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1177 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 1193 group_bsyms len: 1 -[Symbol name=mul] -[t1547 = prims.mul(t1536, t1546) # t1547: "cuda:0 f32[1, 512, 4096]"] - -cur node 1186 group_bsyms len: 1 -[Symbol name=mul] -[t1537 = prims.mul(t1536, t1536) # t1537: "cuda:0 f32[1, 512, 4096]"] - -cur node 1206 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1229 group_bsyms len: 1 -[Symbol name=mul] -[t1586 = prims.mul(t1575, t1585) # t1586: "cuda:0 f32[1, 512, 4096]"] - -cur node 1222 group_bsyms len: 1 -[Symbol name=mul] -[t1576 = prims.mul(t1575, t1575) # t1576: "cuda:0 f32[1, 512, 4096]"] - -cur node 1273 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1243 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1258 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1275 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1277 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 1293 group_bsyms len: 1 -[Symbol name=mul] -[t1692 = prims.mul(t1681, t1691) # t1692: "cuda:0 f32[1, 512, 4096]"] - -cur node 1286 group_bsyms len: 1 -[Symbol name=mul] -[t1682 = prims.mul(t1681, t1681) # t1682: "cuda:0 f32[1, 512, 4096]"] - -cur node 1306 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1329 group_bsyms len: 1 -[Symbol name=mul] -[t1731 = prims.mul(t1720, t1730) # t1731: "cuda:0 f32[1, 512, 4096]"] - -cur node 1322 group_bsyms len: 1 -[Symbol name=mul] -[t1721 = prims.mul(t1720, t1720) # t1721: "cuda:0 f32[1, 512, 4096]"] - -cur node 1373 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1343 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1358 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1375 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1377 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 1393 group_bsyms len: 1 -[Symbol name=mul] -[t1837 = prims.mul(t1826, t1836) # t1837: "cuda:0 f32[1, 512, 4096]"] - -cur node 1386 group_bsyms len: 1 -[Symbol name=mul] -[t1827 = prims.mul(t1826, t1826) # t1827: "cuda:0 f32[1, 512, 4096]"] - -cur node 1406 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1429 group_bsyms len: 1 -[Symbol name=mul] -[t1876 = prims.mul(t1865, t1875) # t1876: "cuda:0 f32[1, 512, 4096]"] - -cur node 1422 group_bsyms len: 1 -[Symbol name=mul] -[t1866 = prims.mul(t1865, t1865) # t1866: "cuda:0 f32[1, 512, 4096]"] - -cur node 1473 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1443 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1458 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1475 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1477 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 1493 group_bsyms len: 1 -[Symbol name=mul] -[t1982 = prims.mul(t1971, t1981) # t1982: "cuda:0 f32[1, 512, 4096]"] - -cur node 1486 group_bsyms len: 1 -[Symbol name=mul] -[t1972 = prims.mul(t1971, t1971) # t1972: "cuda:0 f32[1, 512, 4096]"] - -cur node 1506 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1529 group_bsyms len: 1 -[Symbol name=mul] -[t2021 = prims.mul(t2010, t2020) # t2021: "cuda:0 f32[1, 512, 4096]"] - -cur node 1522 group_bsyms len: 1 -[Symbol name=mul] -[t2011 = prims.mul(t2010, t2010) # t2011: "cuda:0 f32[1, 512, 4096]"] - -cur node 1573 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1543 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1558 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1575 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1577 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 1593 group_bsyms len: 1 -[Symbol name=mul] -[t2127 = prims.mul(t2116, t2126) # t2127: "cuda:0 f32[1, 512, 4096]"] - -cur node 1586 group_bsyms len: 1 -[Symbol name=mul] -[t2117 = prims.mul(t2116, t2116) # t2117: "cuda:0 f32[1, 512, 4096]"] - -cur node 1606 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1629 group_bsyms len: 1 -[Symbol name=mul] -[t2166 = prims.mul(t2155, t2165) # t2166: "cuda:0 f32[1, 512, 4096]"] - -cur node 1622 group_bsyms len: 1 -[Symbol name=mul] -[t2156 = prims.mul(t2155, t2155) # t2156: "cuda:0 f32[1, 512, 4096]"] - -cur node 1673 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1643 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1658 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1675 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1677 group_bsyms len: 1 -[Symbol name=cudnn_sdpa_fwd] -[] - -cur node 1693 group_bsyms len: 1 -[Symbol name=mul] -[t2272 = prims.mul(t2261, t2271) # t2272: "cuda:0 f32[1, 512, 4096]"] - -cur node 1686 group_bsyms len: 1 -[Symbol name=mul] -[t2262 = prims.mul(t2261, t2261) # t2262: "cuda:0 f32[1, 512, 4096]"] - -cur node 1706 group_bsyms len: 1 -[Symbol name=reciprocal] -[] - -cur node 1729 group_bsyms len: 1 -[Symbol name=mul] -[t2311 = prims.mul(t2300, t2310) # t2311: "cuda:0 f32[1, 512, 4096]"] - -cur node 1722 group_bsyms len: 1 -[Symbol name=mul] -[t2301 = prims.mul(t2300, t2300) # t2301: "cuda:0 f32[1, 512, 4096]"] - -cur node 126 group_bsyms len: 1 -[Symbol name=add] -[t10 = prims.add(t9, 1e-05) # t10: "cuda:0 f32[1, 512, 1]"] - -cur node 194 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 187 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 144 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 145 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 151 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 160 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 166 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 159 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 178 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 207 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 230 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 223 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 251 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 244 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 245 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 266 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 259 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 260 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 278 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 294 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 287 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 307 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 330 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 323 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 344 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 345 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 351 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 360 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 366 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 359 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 378 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 394 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 387 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 407 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 430 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 423 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 451 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 444 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 445 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 466 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 459 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 460 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 478 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 494 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 487 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 507 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 530 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 523 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 544 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 545 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 551 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 560 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 566 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 559 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 578 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 594 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 587 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 607 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 630 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 623 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 651 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 644 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 645 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 666 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 659 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 660 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 678 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 694 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 687 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 707 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 730 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 723 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 744 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 745 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 751 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 760 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 766 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 759 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 778 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 794 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 787 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 807 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 830 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 823 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 851 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 844 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 845 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 866 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 859 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 860 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 878 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 894 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 887 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 907 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 930 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 923 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 944 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 945 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 951 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 960 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 966 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 959 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 978 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 994 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 987 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1007 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1030 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1023 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1051 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1044 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1045 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1066 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1059 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1060 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1078 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1094 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1087 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1107 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1130 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1123 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1144 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1145 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1151 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1160 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1166 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1159 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1178 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1194 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1187 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1207 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1230 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1223 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1251 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1244 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1245 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1266 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1259 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1260 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1278 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1294 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1287 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1307 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1330 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1323 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1344 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1345 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1351 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1360 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1366 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1359 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1378 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1394 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1387 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1407 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1430 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1423 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1451 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1444 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1445 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1466 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1459 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1460 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1478 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1494 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1487 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1507 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1530 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1523 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1544 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1545 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1551 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1560 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1566 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1559 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1578 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1594 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1587 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1607 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1630 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1623 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1651 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1644 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1645 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1666 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1659 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1660 group_bsyms len: 1 -[Symbol name=slice_prim] -[] - -cur node 1678 group_bsyms len: 1 -[Symbol name=transpose] -[] - -cur node 1694 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1687 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 1707 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1730 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1723 group_bsyms len: 1 -[Symbol name=sum] -[] - -cur node 127 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 196 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 188 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 149 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 146 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 161 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 164 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 179 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 209 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 232 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 224 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 249 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 246 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 264 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 261 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 279 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 296 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 288 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 309 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 332 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 324 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 349 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 346 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 361 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 364 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 379 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 396 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 388 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 409 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 432 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 424 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 449 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 446 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 464 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 461 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 479 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 496 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 488 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 509 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 532 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 524 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 549 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 546 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 561 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 564 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 579 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 596 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 588 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 609 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 632 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 624 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 649 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 646 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 664 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 661 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 679 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 696 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 688 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 709 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 732 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 724 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 749 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 746 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 761 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 764 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 779 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 796 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 788 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 809 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 832 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 824 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 849 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 846 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 864 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 861 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 879 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 896 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 888 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 909 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 932 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 924 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 949 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 946 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 961 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 964 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 979 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 996 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 988 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1009 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1032 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1024 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1049 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1046 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1064 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1061 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1079 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1096 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1088 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1109 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1132 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1124 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1149 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1146 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1161 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1164 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1179 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1196 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1188 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1209 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1232 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1224 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1249 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1246 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1264 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1261 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1279 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1296 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1288 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1309 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1332 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1324 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1349 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1346 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1361 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1364 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1379 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1396 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1388 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1409 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1432 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1424 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1449 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1446 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1464 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1461 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1479 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1496 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1488 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1509 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1532 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1524 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1549 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1546 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1561 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1564 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1579 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1596 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1588 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1609 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1632 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1624 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1649 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1646 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1664 group_bsyms len: 1 -[Symbol name=cat] -[] - -cur node 1661 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1679 group_bsyms len: 1 -[Symbol name=reshape] -[] - -cur node 1696 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1688 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1709 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1732 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1724 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 128 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 189 group_bsyms len: 1 -[Symbol name=true_divide] -[t92 = prims.div(t90, 4096.0) # t92: "cuda:0 f32[1, 512, 1]"] - -cur node 154 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 147 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 162 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 169 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 225 group_bsyms len: 1 -[Symbol name=true_divide] -[t131 = prims.div(t129, 4096.0) # t131: "cuda:0 f32[1, 512, 1]"] - -cur node 254 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 247 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 269 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 262 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 289 group_bsyms len: 1 -[Symbol name=true_divide] -[t237 = prims.div(t235, 4096.0) # t237: "cuda:0 f32[1, 512, 1]"] - -cur node 325 group_bsyms len: 1 -[Symbol name=true_divide] -[t276 = prims.div(t274, 4096.0) # t276: "cuda:0 f32[1, 512, 1]"] - -cur node 354 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 347 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 362 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 369 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 389 group_bsyms len: 1 -[Symbol name=true_divide] -[t382 = prims.div(t380, 4096.0) # t382: "cuda:0 f32[1, 512, 1]"] - -cur node 425 group_bsyms len: 1 -[Symbol name=true_divide] -[t421 = prims.div(t419, 4096.0) # t421: "cuda:0 f32[1, 512, 1]"] - -cur node 454 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 447 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 469 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 462 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 489 group_bsyms len: 1 -[Symbol name=true_divide] -[t527 = prims.div(t525, 4096.0) # t527: "cuda:0 f32[1, 512, 1]"] - -cur node 525 group_bsyms len: 1 -[Symbol name=true_divide] -[t566 = prims.div(t564, 4096.0) # t566: "cuda:0 f32[1, 512, 1]"] - -cur node 554 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 547 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 562 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 569 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 589 group_bsyms len: 1 -[Symbol name=true_divide] -[t672 = prims.div(t670, 4096.0) # t672: "cuda:0 f32[1, 512, 1]"] - -cur node 625 group_bsyms len: 1 -[Symbol name=true_divide] -[t711 = prims.div(t709, 4096.0) # t711: "cuda:0 f32[1, 512, 1]"] - -cur node 654 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 647 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 669 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 662 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 689 group_bsyms len: 1 -[Symbol name=true_divide] -[t817 = prims.div(t815, 4096.0) # t817: "cuda:0 f32[1, 512, 1]"] - -cur node 725 group_bsyms len: 1 -[Symbol name=true_divide] -[t856 = prims.div(t854, 4096.0) # t856: "cuda:0 f32[1, 512, 1]"] - -cur node 754 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 747 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 762 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 769 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 789 group_bsyms len: 1 -[Symbol name=true_divide] -[t962 = prims.div(t960, 4096.0) # t962: "cuda:0 f32[1, 512, 1]"] - -cur node 825 group_bsyms len: 1 -[Symbol name=true_divide] -[t1001 = prims.div(t999, 4096.0) # t1001: "cuda:0 f32[1, 512, 1]"] - -cur node 854 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 847 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 869 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 862 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 889 group_bsyms len: 1 -[Symbol name=true_divide] -[t1107 = prims.div(t1105, 4096.0) # t1107: "cuda:0 f32[1, 512, 1]"] - -cur node 925 group_bsyms len: 1 -[Symbol name=true_divide] -[t1146 = prims.div(t1144, 4096.0) # t1146: "cuda:0 f32[1, 512, 1]"] - -cur node 954 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 947 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 962 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 969 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 989 group_bsyms len: 1 -[Symbol name=true_divide] -[t1252 = prims.div(t1250, 4096.0) # t1252: "cuda:0 f32[1, 512, 1]"] - -cur node 1025 group_bsyms len: 1 -[Symbol name=true_divide] -[t1291 = prims.div(t1289, 4096.0) # t1291: "cuda:0 f32[1, 512, 1]"] - -cur node 1054 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1047 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1069 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1062 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1089 group_bsyms len: 1 -[Symbol name=true_divide] -[t1397 = prims.div(t1395, 4096.0) # t1397: "cuda:0 f32[1, 512, 1]"] - -cur node 1125 group_bsyms len: 1 -[Symbol name=true_divide] -[t1436 = prims.div(t1434, 4096.0) # t1436: "cuda:0 f32[1, 512, 1]"] - -cur node 1154 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1147 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1162 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1169 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1189 group_bsyms len: 1 -[Symbol name=true_divide] -[t1542 = prims.div(t1540, 4096.0) # t1542: "cuda:0 f32[1, 512, 1]"] - -cur node 1225 group_bsyms len: 1 -[Symbol name=true_divide] -[t1581 = prims.div(t1579, 4096.0) # t1581: "cuda:0 f32[1, 512, 1]"] - -cur node 1254 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1247 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1269 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1262 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1289 group_bsyms len: 1 -[Symbol name=true_divide] -[t1687 = prims.div(t1685, 4096.0) # t1687: "cuda:0 f32[1, 512, 1]"] - -cur node 1325 group_bsyms len: 1 -[Symbol name=true_divide] -[t1726 = prims.div(t1724, 4096.0) # t1726: "cuda:0 f32[1, 512, 1]"] - -cur node 1354 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1347 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1362 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1369 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1389 group_bsyms len: 1 -[Symbol name=true_divide] -[t1832 = prims.div(t1830, 4096.0) # t1832: "cuda:0 f32[1, 512, 1]"] - -cur node 1425 group_bsyms len: 1 -[Symbol name=true_divide] -[t1871 = prims.div(t1869, 4096.0) # t1871: "cuda:0 f32[1, 512, 1]"] - -cur node 1454 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1447 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1469 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1462 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1489 group_bsyms len: 1 -[Symbol name=true_divide] -[t1977 = prims.div(t1975, 4096.0) # t1977: "cuda:0 f32[1, 512, 1]"] - -cur node 1525 group_bsyms len: 1 -[Symbol name=true_divide] -[t2016 = prims.div(t2014, 4096.0) # t2016: "cuda:0 f32[1, 512, 1]"] - -cur node 1554 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1547 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1562 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1569 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1589 group_bsyms len: 1 -[Symbol name=true_divide] -[t2122 = prims.div(t2120, 4096.0) # t2122: "cuda:0 f32[1, 512, 1]"] - -cur node 1625 group_bsyms len: 1 -[Symbol name=true_divide] -[t2161 = prims.div(t2159, 4096.0) # t2161: "cuda:0 f32[1, 512, 1]"] - -cur node 1654 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1647 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1669 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1662 group_bsyms len: 1 -[Symbol name=neg] -[] - -cur node 1689 group_bsyms len: 1 -[Symbol name=true_divide] -[t2267 = prims.div(t2265, 4096.0) # t2267: "cuda:0 f32[1, 512, 1]"] - -cur node 1725 group_bsyms len: 1 -[Symbol name=true_divide] -[t2306 = prims.div(t2304, 4096.0) # t2306: "cuda:0 f32[1, 512, 1]"] - -cur node 190 group_bsyms len: 1 -[Symbol name=add] -[t94 = prims.add(t92, 1e-05) # t94: "cuda:0 f32[1, 512, 1]"] - -cur node 148 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 163 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 226 group_bsyms len: 1 -[Symbol name=add] -[t133 = prims.add(t131, 1e-05) # t133: "cuda:0 f32[1, 512, 1]"] - -cur node 248 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 263 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 290 group_bsyms len: 1 -[Symbol name=add] -[t239 = prims.add(t237, 1e-05) # t239: "cuda:0 f32[1, 512, 1]"] - -cur node 326 group_bsyms len: 1 -[Symbol name=add] -[t278 = prims.add(t276, 1e-05) # t278: "cuda:0 f32[1, 512, 1]"] - -cur node 348 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 363 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 390 group_bsyms len: 1 -[Symbol name=add] -[t384 = prims.add(t382, 1e-05) # t384: "cuda:0 f32[1, 512, 1]"] - -cur node 426 group_bsyms len: 1 -[Symbol name=add] -[t423 = prims.add(t421, 1e-05) # t423: "cuda:0 f32[1, 512, 1]"] - -cur node 448 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 463 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 490 group_bsyms len: 1 -[Symbol name=add] -[t529 = prims.add(t527, 1e-05) # t529: "cuda:0 f32[1, 512, 1]"] - -cur node 526 group_bsyms len: 1 -[Symbol name=add] -[t568 = prims.add(t566, 1e-05) # t568: "cuda:0 f32[1, 512, 1]"] - -cur node 548 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 563 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 590 group_bsyms len: 1 -[Symbol name=add] -[t674 = prims.add(t672, 1e-05) # t674: "cuda:0 f32[1, 512, 1]"] - -cur node 626 group_bsyms len: 1 -[Symbol name=add] -[t713 = prims.add(t711, 1e-05) # t713: "cuda:0 f32[1, 512, 1]"] - -cur node 648 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 663 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 690 group_bsyms len: 1 -[Symbol name=add] -[t819 = prims.add(t817, 1e-05) # t819: "cuda:0 f32[1, 512, 1]"] - -cur node 726 group_bsyms len: 1 -[Symbol name=add] -[t858 = prims.add(t856, 1e-05) # t858: "cuda:0 f32[1, 512, 1]"] - -cur node 748 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 763 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 790 group_bsyms len: 1 -[Symbol name=add] -[t964 = prims.add(t962, 1e-05) # t964: "cuda:0 f32[1, 512, 1]"] - -cur node 826 group_bsyms len: 1 -[Symbol name=add] -[t1003 = prims.add(t1001, 1e-05) # t1003: "cuda:0 f32[1, 512, 1]"] - -cur node 848 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 863 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 890 group_bsyms len: 1 -[Symbol name=add] -[t1109 = prims.add(t1107, 1e-05) # t1109: "cuda:0 f32[1, 512, 1]"] - -cur node 926 group_bsyms len: 1 -[Symbol name=add] -[t1148 = prims.add(t1146, 1e-05) # t1148: "cuda:0 f32[1, 512, 1]"] - -cur node 948 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 963 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 990 group_bsyms len: 1 -[Symbol name=add] -[t1254 = prims.add(t1252, 1e-05) # t1254: "cuda:0 f32[1, 512, 1]"] - -cur node 1026 group_bsyms len: 1 -[Symbol name=add] -[t1293 = prims.add(t1291, 1e-05) # t1293: "cuda:0 f32[1, 512, 1]"] - -cur node 1048 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1063 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1090 group_bsyms len: 1 -[Symbol name=add] -[t1399 = prims.add(t1397, 1e-05) # t1399: "cuda:0 f32[1, 512, 1]"] - -cur node 1126 group_bsyms len: 1 -[Symbol name=add] -[t1438 = prims.add(t1436, 1e-05) # t1438: "cuda:0 f32[1, 512, 1]"] - -cur node 1148 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1163 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1190 group_bsyms len: 1 -[Symbol name=add] -[t1544 = prims.add(t1542, 1e-05) # t1544: "cuda:0 f32[1, 512, 1]"] - -cur node 1226 group_bsyms len: 1 -[Symbol name=add] -[t1583 = prims.add(t1581, 1e-05) # t1583: "cuda:0 f32[1, 512, 1]"] - -cur node 1248 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1263 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1290 group_bsyms len: 1 -[Symbol name=add] -[t1689 = prims.add(t1687, 1e-05) # t1689: "cuda:0 f32[1, 512, 1]"] - -cur node 1326 group_bsyms len: 1 -[Symbol name=add] -[t1728 = prims.add(t1726, 1e-05) # t1728: "cuda:0 f32[1, 512, 1]"] - -cur node 1348 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1363 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1390 group_bsyms len: 1 -[Symbol name=add] -[t1834 = prims.add(t1832, 1e-05) # t1834: "cuda:0 f32[1, 512, 1]"] - -cur node 1426 group_bsyms len: 1 -[Symbol name=add] -[t1873 = prims.add(t1871, 1e-05) # t1873: "cuda:0 f32[1, 512, 1]"] - -cur node 1448 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1463 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1490 group_bsyms len: 1 -[Symbol name=add] -[t1979 = prims.add(t1977, 1e-05) # t1979: "cuda:0 f32[1, 512, 1]"] - -cur node 1526 group_bsyms len: 1 -[Symbol name=add] -[t2018 = prims.add(t2016, 1e-05) # t2018: "cuda:0 f32[1, 512, 1]"] - -cur node 1548 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1563 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1590 group_bsyms len: 1 -[Symbol name=add] -[t2124 = prims.add(t2122, 1e-05) # t2124: "cuda:0 f32[1, 512, 1]"] - -cur node 1626 group_bsyms len: 1 -[Symbol name=add] -[t2163 = prims.add(t2161, 1e-05) # t2163: "cuda:0 f32[1, 512, 1]"] - -cur node 1648 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1663 group_bsyms len: 1 -[Symbol name=convert_element_type] -[] - -cur node 1690 group_bsyms len: 1 -[Symbol name=add] -[t2269 = prims.add(t2267, 1e-05) # t2269: "cuda:0 f32[1, 512, 1]"] - -cur node 1726 group_bsyms len: 1 -[Symbol name=add] -[t2308 = prims.add(t2306, 1e-05) # t2308: "cuda:0 f32[1, 512, 1]"] - -cur node 191 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 227 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 291 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 327 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 391 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 427 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 491 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 527 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 591 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 627 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 691 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 727 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 791 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 827 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 891 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 927 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 991 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1027 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1091 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1127 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1191 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1227 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1291 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1327 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1391 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1427 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1491 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1527 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1591 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1627 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1691 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 1727 group_bsyms len: 1 -[Symbol name=rsqrt] -[] - -cur node 192 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 228 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 292 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 328 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 392 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 428 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 492 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 528 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 592 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 628 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 692 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 728 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 792 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 828 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 892 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 928 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 992 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1028 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1092 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1128 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1192 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1228 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1292 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1328 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1392 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1428 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1492 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1528 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1592 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1628 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] - -cur node 1692 group_bsyms len: 1 -[Symbol name=broadcast_in_dim] -[] diff --git a/examples/dev/my_graph.png b/examples/dev/my_graph.png deleted file mode 100644 index 08e87a916a60d35ebd037f88de169c4df81e2578..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 144 zcmeAS@N?(olHy`uVBq!ia0vp^+#t-s1|(OmDOUqhY)RhkE)4%caKYZ?lYt_xo-U3d z5>wA!*vNapfQR{DdO}9?o5iW2l@lri-Y@JCFk?IPOKRV)-5<^Lg_#QY`d29_*)3cW o-B4;^{QX1nhd)c%419OTIjHb$ogtHI2Q-ku)78&qol`;+0Q9ObbpQYW diff --git a/examples/dev/simple.py b/examples/dev/simple.py index 631caa8a89..b33f31b21e 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -10,30 +10,28 @@ def __init__(self, in_features, out_features) -> None: def forward(self, x: torch.Tensor): a = x + x - # a_silu = self.silu(a) b: torch.Tensor = self.linear(a) c = b * b - # c_silu = self.silu(c) d = c + c - return d + return self.silu(d) with torch.device('cuda'): - multiplier = 10 + multiplier = 1000 in_features = 20 * multiplier out_features = 30 * multiplier model = Module(in_features, out_features) x = torch.randn(128, in_features) - jmodel = thunder.jit(model) + jmodel = thunder.jit(model, autotune_executors=True, executors=['nvfuser', 'torchcompile', 'torch']) - for _ in range(100): - start = time.time_ns() + for _ in range(10): + start = time.perf_counter_ns() ans = jmodel(x) - end = time.time_ns() - # print('---------------------------------------------- all traces') - # for t in thunder.last_traces(jmodel): - # print(t) - # print('##############################################') + torch.autograd.grad(ans.sum(), model.parameters()) + torch.cuda.synchronize() + end = time.perf_counter_ns() print(f'tot time = {(end - start) / 1000000} ms') + print(thunder.last_backward_traces(jmodel)[-1]) + diff --git a/examples/dev/simple_log.out b/examples/dev/simple_log.out deleted file mode 100644 index aa715847ae..0000000000 --- a/examples/dev/simple_log.out +++ /dev/null @@ -1,132 +0,0 @@ -Interpretation used: INTERPRETATION_OPTIONS.TRANSLATE_PYTHON -comp trce before -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -# No signature available -comp trace after -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - - # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x - result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - return result -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - - # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x - result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - return result -============================================ START: LABEL default -============================================ START: LABEL computation_trc -> backward_trc = None -============================================ START: post_optimization_transforms -[] -============================================ END: post_optimization_transforms -============================================ START: before computation_trc python Callable -# Constructed by Delete Last Used (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - [result] = nvFusion0(x) - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - del x - return result -============================================ END: before computation_trc python Callable ----------------------------------------------- all traces -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - - # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x - result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - return result -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - - # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x - result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - return result -############################################## -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - - # /workspace/pj/lightning-thunder/examples/dev/simple.py:10: a = x + x - result = ltorch.add(x, x, alpha=None) # result: "cuda:0 f32[2, 2]" - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - return result -############################################## -# Constructed by Transform for execution (took 1 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - [result] = nvFusion0(x) - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - return result -############################################## -# Constructed by Delete Last Used (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def computation(x): - # x: "cuda:0 f32[2, 2]" - [result] = nvFusion0(x) - # result = prims.add(x, x) # result: "cuda:0 f32[2, 2]" - del x - return result -############################################## ----------------------------------------------- ans -tensor([[-1.9710, -4.7323], - [ 0.1026, 0.5416]], device='cuda:0') diff --git a/thunder/__init__.py b/thunder/__init__.py index 84c375bace..c38482ddbc 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -285,6 +285,7 @@ def jit( additional_transforms: list[AdditionalTransform] | None = None, post_optimization_transforms: list[PostOptimizationTransform] | None = None, record_history: bool = False, + autotune_executors: bool = True, **compile_options, # TODO RC1 Make this explicit -- dict of options ) -> Callable: """Just-in-time compile a callable (function or model). @@ -582,7 +583,7 @@ def get_computation_and_inputs(*args, **kwargs): # transform_for_execution and various sorting of symbols, # applying transform_for_execution after this would be # breaking the order of operations - computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) + computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, autotune_executors, *inps) # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward extraces = cs.last_traces diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 7f6216a2b6..d22d78da9e 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -465,7 +465,7 @@ def benchmark_traces(self): del res self.debug_msg += f'Trace name = [{label}] - Time = {trace_time / 1000000} ms\n{trace}\n\n' self.log(f'Benchmark trace "{label}" (time = {trace_time / 1000000} ms):\n{trace}') - if trace_time < min_run_time: + if trace_time < min_run_time and label=='priority_list': min_run_time = trace_time optimal_trace = trace best_label = label diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 57bf6c58ea..0b2c393982 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -106,12 +106,12 @@ def backward(ctx, *args): return (None, None, None, None, None, *([None] * n_grads)) # TODO (matteochen): add control for using autotuner or not -def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args): +def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, autotune_executors, /, *flat_args): from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace from thunder.distributed.transforms import FSDPCommBucketing from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops - from thunder.executors.passes import del_last_used, autotune_transform_for_execution + from thunder.executors.passes import del_last_used, transform_for_execution, autotune_transform_for_execution from thunder.visualizer.visualizer_helper import Visualizer visualizer = Visualizer(produce_hidden=False) @@ -170,11 +170,17 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # TODO Restore request for no rematerialization visualizer.set_fw_initial_trace(fw_trace) - fw_extrace = autotune_transform_for_execution( - fw_trace, - executors_list=compile_data.executors_list, - visualizer=visualizer - ) + if autotune_executors: + fw_extrace = autotune_transform_for_execution( + fw_trace, + executors_list=compile_data.executors_list, + visualizer=visualizer + ) + else: + fw_extrace = transform_for_execution( + fw_trace, + executors_list=compile_data.executors_list + ) fw_traces.append(fw_extrace) visualizer.set_fw_optimized_trace(fw_extrace) @@ -210,11 +216,17 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization visualizer.set_bw_initial_trace(bw_trace) - bw_extrace = autotune_transform_for_execution( - bw_trace, - executors_list=compile_data.executors_list, - visualizer=visualizer - ) + if autotune_executors: + bw_extrace = autotune_transform_for_execution( + bw_trace, + executors_list=compile_data.executors_list, + visualizer=visualizer + ) + else: + bw_extrace = transform_for_execution( + bw_trace, + executors_list=compile_data.executors_list, + ) bw_traces.append(bw_extrace) visualizer.set_bw_optimized_trace(bw_extrace) From 6f93cba49ac07791cccb28c2999fc2951a8669d4 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 16 Jul 2024 14:53:50 +0200 Subject: [PATCH 014/171] Added torch empty cache during benchmark execution / added trace provenance to optimizer object --- examples/dev/LLaMAMLP.py | 67 ++++++++++++++++++-------- examples/dev/simple.py | 49 ++++++++++++++----- thunder/backend_optimizer/optimizer.py | 41 +++++++++++----- 3 files changed, 116 insertions(+), 41 deletions(-) diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 2a68d16e49..500bfb0bf4 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -1,3 +1,4 @@ +import inspect import torch import thunder import time @@ -15,31 +16,59 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) with torch.device('cuda'): - a = 4096 * 3 - b = 11008 * 3 - model = LLaMAMLP(a, b) + a = 4096 * 1 + b = 11008 * 1 x = torch.randn(2, 2048, a, requires_grad=True) - jmodel = thunder.jit(model) + jmodel_def = thunder.jit(LLaMAMLP(a, b), autotune_executors=False) + jmodel_auto = thunder.jit(LLaMAMLP(a, b), autotune_executors=True) + warm_up_iters = 2 + iters = 10 - tot_time = 0 - iters = 12 for i in range(iters): - start = time.perf_counter_ns() - ans = jmodel(x) + start_fw = time.perf_counter_ns() + y = jmodel_auto(x) torch.cuda.synchronize() - end = time.perf_counter_ns() + end_fw = time.perf_counter_ns() + grad_outputs = torch.ones_like(y) + torch.cuda.synchronize() + start_bw = time.perf_counter_ns() + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + torch.cuda.synchronize() + end_bw = time.perf_counter_ns() + torch.cuda.empty_cache() + # source = inspect.getsource(y.grad_fn.compiled_backward) - # Skip the model without cache - if i > 1: - tot_time += (end - start) - print(f'tot time = {(end - start) / 1000000} ms') + if i >= warm_up_iters: + print(f'tot time auto forward = {(end_fw - start_fw) / 1000000} ms') + print(f'tot time auto backward = {(end_bw - start_bw) / 1000000} ms') + torch.cuda.synchronize() + torch.cuda.empty_cache() + + for i in range(iters): + start_fw = time.perf_counter_ns() + y = jmodel_def(x) + torch.cuda.synchronize() + end_fw = time.perf_counter_ns() + grad_outputs = torch.ones_like(y) + torch.cuda.synchronize() + start_bw = time.perf_counter_ns() + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + torch.cuda.synchronize() + end_bw = time.perf_counter_ns() + torch.cuda.empty_cache() + # source = inspect.getsource(y.grad_fn.compiled_backward) - # for t in thunder.last_traces(jmodel): - # print(t) - print(thunder.last_traces(jmodel)[-1]) - print(thunder.last_backward_traces(jmodel)[-1]) - print(f'Mean time = {(tot_time/(iters-2))/1000000} ms') + if i >= warm_up_iters: + print(f'tot time def forward = {(end_fw - start_fw) / 1000000} ms') + print(f'tot time def backward = {(end_bw - start_bw) / 1000000} ms') + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') - print('deviation:', (jmodel(x) - model(x)).abs().max().item()) + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') diff --git a/examples/dev/simple.py b/examples/dev/simple.py index b33f31b21e..e0c7d65e11 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -1,6 +1,7 @@ import torch import thunder import time +import inspect class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: @@ -19,19 +20,45 @@ def forward(self, x: torch.Tensor): multiplier = 1000 in_features = 20 * multiplier out_features = 30 * multiplier - model = Module(in_features, out_features) - x = torch.randn(128, in_features) - jmodel = thunder.jit(model, autotune_executors=True, executors=['nvfuser', 'torchcompile', 'torch']) - - for _ in range(10): - start = time.perf_counter_ns() - ans = jmodel(x) - torch.autograd.grad(ans.sum(), model.parameters()) + jmodel_default = thunder.jit(Module(in_features, out_features), autotune_executors=False) + jmodel_autotune = thunder.jit(Module(in_features, out_features), autotune_executors=True) + x = torch.randn(128, in_features, requires_grad=True) + warm_up_iters = 3 + for i in range(10): + start_fw = time.perf_counter_ns() + y = jmodel_default(x) + torch.cuda.synchronize() + end_fw = time.perf_counter_ns() + grad_outputs = torch.ones_like(y) torch.cuda.synchronize() - end = time.perf_counter_ns() + start_bw = time.perf_counter_ns() + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + torch.cuda.synchronize() + end_bw = time.perf_counter_ns() + torch.cuda.empty_cache() + # source = inspect.getsource(y.grad_fn.compiled_backward) + + if i >= warm_up_iters: + print(f'tot time default forward = {(end_fw - start_fw) / 1000000} ms') + print(f'tot time default backward = {(end_bw - start_bw) / 1000000} ms') - print(f'tot time = {(end - start) / 1000000} ms') + for i in range(10): + start_fw = time.perf_counter_ns() + y = jmodel_autotune(x) + torch.cuda.synchronize() + end_fw = time.perf_counter_ns() + grad_outputs = torch.ones_like(y) + torch.cuda.synchronize() + start_bw = time.perf_counter_ns() + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + torch.cuda.synchronize() + end_bw = time.perf_counter_ns() + torch.cuda.empty_cache() + # source = inspect.getsource(y.grad_fn.compiled_backward) - print(thunder.last_backward_traces(jmodel)[-1]) + if i >= warm_up_iters: + print(f'tot time autotune forward = {(end_fw - start_fw) / 1000000} ms') + print(f'tot time autotune backward = {(end_bw - start_bw) / 1000000} ms') + # print('\n\n', thunder.last_backward_traces(jmodel)[-1]) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index d22d78da9e..0052d961bb 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -7,7 +7,7 @@ from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx from thunder.core.utils import check, safe_map_flat from thunder.executors.data_dependent_partition import Graph, Node -from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_all_executors, get_always_executors +from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_all_executors, get_always_executors, resolve_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Any, Hashable import thunder @@ -28,13 +28,15 @@ class BackendOptimizer(): def log(self, what: str): print(f'================================================================================ Autotune: {what}') - def __init__(self, trace: TraceCtx, executors: Sequence[Executor], produce_log=True, log_file_name='autotune_debug.log', visualizer: Visualizer | None = None) -> None: + def __init__(self, trace: TraceCtx, priority_executors: Sequence[Executor], produce_log=True, log_file_name='autotune_debug.log', visualizer: Visualizer | None = None) -> None: + # Add more supported ones + self.executors: Sequence[Executor] = resolve_executors(['torch', 'python', 'nvfuser', 'torchcompile', 'sdpa', 'cudnn']) + self.priority_executors: Sequence[Executor] = priority_executors self.trace: TraceCtx = trace self.incremental_search_out_trace: TraceCtx self.optimal_trace: TraceCtx = trace self.computation_graph: Graph = Graph(trace) - self.executors: Sequence[Executor] = executors - self.fusion_executors: Sequence[FusionExecutor] = [ex for ex in executors if isinstance(ex, FusionExecutor)] + self.fusion_executors: Sequence[FusionExecutor] = [ex for ex in self.executors if isinstance(ex, FusionExecutor)] self.empty_executor_hashable_placeholder: str = 'empty' self.placement_options: list[list[Executor]] = [] self.optimized_traces: list[dict[str, TraceCtx]] = [] @@ -69,8 +71,13 @@ def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: self.symbol = symbol self.idx = idx + class TraceType(Enum): + COMPUTATIONAL = 0 + FW = 1 + BW = 2 + # TODO (matteochen): this has a lot in common with the exaustive search, compact them - def build_placement_options_incremental(self): + def build_placement_options_incremental(self, whoami: TraceType = TraceType.COMPUTATIONAL): import sys old_max_recursion = sys.getrecursionlimit() @@ -79,6 +86,14 @@ def build_placement_options_incremental(self): # Last index inclusive def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: list[Executor]) -> tuple[float, TraceCtx]: + def safe_update_dict(d: dict, key_one, key_two, value): + if key_one not in d: + d[key_one] = { + key_two: value + } + else: + d[key_one][key_two] = value + # Retrive all output tensors from each subregion tensors = [] for i in range(last_idx+1): @@ -103,14 +118,14 @@ def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: li # Place the assigned symbols placed_t = self.place_optimizers(t, configuration) - cost, answer = benchmark_trace(placed_t, iters=10) + cost, answer = benchmark_trace(placed_t, iters=5) del answer self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') self.log(f'Assigned executor = {configuration[-2].name}') self.log(f'Time = {cost/1000000} ms') # TODO (matteochen): log this to file - self.partial_costs[t] = cost + safe_update_dict(self.partial_costs, whoami, t, cost) return cost, placed_t # We assign an internal id to each symbol based on its idx inside the bound_symbols list @@ -409,7 +424,7 @@ def greedy(): self.optimized_traces.append({'fused_greedy': trace_greedy_fused}) # 3. Try the priority list approach - trace_priority = transform_for_execution(self.trace, self.executors) + trace_priority = transform_for_execution(self.trace, self.priority_executors) self.optimized_traces.append({'priority_list': trace_priority}) # There are no hidden placements hence do not call the visualizer @@ -465,7 +480,7 @@ def benchmark_traces(self): del res self.debug_msg += f'Trace name = [{label}] - Time = {trace_time / 1000000} ms\n{trace}\n\n' self.log(f'Benchmark trace "{label}" (time = {trace_time / 1000000} ms):\n{trace}') - if trace_time < min_run_time and label=='priority_list': + if trace_time < min_run_time: min_run_time = trace_time optimal_trace = trace best_label = label @@ -486,14 +501,18 @@ def benchmark_trace(trace: TraceCtx, iters: int = 1) -> tuple[float, Any]: input_args = [] def compute_time_cost(fn: Callable, iters: int, *args) -> tuple[float, Any]: + warm_up_iters = 3 total_time = 0 out = None - for _ in range(iters): + torch.cuda.empty_cache() + for i in range(iters + warm_up_iters): time_s = time.perf_counter_ns() out = fn(*args) torch.cuda.synchronize() time_e = time.perf_counter_ns() - total_time += (time_e - time_s) + torch.cuda.empty_cache() + if i >= warm_up_iters: + total_time += (time_e - time_s) return total_time / iters, out From 0945159c232a70708c827910421d49cb28f25b28 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 16 Jul 2024 23:36:55 +0200 Subject: [PATCH 015/171] Wip fuser ex focus --- thunder/backend_optimizer/optimizer.py | 276 ++++++++++++++++++++++++- 1 file changed, 275 insertions(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 0052d961bb..d8aaaf1be8 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -2,7 +2,7 @@ from enum import Enum from itertools import chain from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable +from thunder.core.proxies import CollectionProxy, Proxy, TensorProxy, variableify, Variable from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx from thunder.core.utils import check, safe_map_flat @@ -405,6 +405,9 @@ def optimize(self, strat: OptimizationStrat = OptimizationStrat.GREEDY): import thunder.core.codeutils as cutils from thunder.executors.passes import transform_for_execution + self.split_fusion_executors_placement() + return + def greedy(): # 1. This builds one option by default self.build_placement_options_incremental() @@ -455,6 +458,277 @@ def exaustive(): else: raise AssertionError('Optimization strat not implemented') + # For each node retrive subgraph + # 1 For each fusion executor compute the min common subgraph + # 2 Select the best runtime for the current subgraph + # 3 Mark nodes as visited + # 4 Nodes which can not be fused do select the best backend + + def split_fusion_executors_placement(self): + import pprint + from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols + + # TODO: parametrize + increment = 1 + + def sequence_hash(s: Sequence) -> str: + name = "" + for e in s: + if isinstance(e, CollectionProxy) or isinstance(e, TensorProxy): + name += e.name + elif e is None: + name += "None" + else: + raise AssertionError(f'What? type = {type(e)}') + return name + + # Benchmark the optimal executor and call this optimal + def get_default_executor(bsym: BoundSymbol): + for ex in self.executors: + if isinstance(ex, FusionExecutor): + continue + if ex.can_execute(bsym): + return ex + return Executor(name=self.empty_executor_hashable_placeholder) + + def get_placed_trace(mapping: dict[str, Executor], bound_symbols: Sequence[BoundSymbol]): + import pprint + self.log(f'Input mapping len = {len(mapping)}:') + pprint.pprint(mapping) + trc = from_trace(self.trace) + trc.bound_symbols = list(bound_symbols) + + # for b in trc.bound_symbols: + # print(b.sym.name) + + # print(f'trc:\n{trc}') + + # Check if this naming is always valid + def is_possible_out(name: str): + num = name[1:] + return num.isdigit() + + def find_original_return_tensors(trace_in: TraceCtx) -> list[str]: + return_bsym = trace_in.bound_symbols[-1] + if return_bsym.sym.name != 'return': + raise AssertionError(f'Expected return symbol got {return_bsym.sym.name}') + + ans = [] + # forward trace + if isinstance(return_bsym.args, tuple): + if isinstance(return_bsym.args[0], dict): + ans.append(return_bsym.args[0]['output']) + else: + ans.extend([s for s in return_bsym.args if s is not None]) + else: + raise AssertionError('Not supported') + + return ans + + def find_last_out_tensor(trace_in: TraceCtx): + m = 0 + t = None + for b in trace_in.bound_symbols: + if b.sym.name == 'return': + continue + if isinstance(b.output, TensorProxy): + if is_possible_out(b.output.name) and int(b.output.name[1:]) > m: + m = int(b.output.name[1:]) + t = b.output + # else: + # raise AssertionError(f'Not implemented, type = {type(b.output)}') + if t is None: + raise AssertionError('Max tensor output not found') + print(f'max tensor out name: {t}') + return t + + return_tensor = [find_last_out_tensor(trc)] + original_returns = find_original_return_tensors(self.trace) + for t in original_returns: + if t in trc.bound_symbols: + return_tensor.append(t) + + print(f'Return tensors: {return_tensor}') + + forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=return_tensor) # Should not be an Interface type at this point + + executor_configuration = [] + empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + keys = [] + for bsym in trc.bound_symbols: + print(f'current bsym {bsym.sym.name} -> type out {type(bsym.output)}') + if bsym.sym.name == 'return': + executor_configuration.append(empty_executor) + keys.append('return') + elif isinstance(bsym.output, Sequence): + seq_hash = sequence_hash(bsym.output) + executor_configuration.append(mapping.get(seq_hash, empty_executor)) + keys.append(seq_hash) + elif isinstance(bsym.output, CollectionProxy) or isinstance(bsym.output, TensorProxy): + if bsym.output.name not in mapping: + raise AssertionError(f'Expected key {bsym.output.name} in mapping {mapping}') + executor_configuration.append(mapping[bsym.output.name]) + keys.append(bsym.output.name) + else: + raise AssertionError(f"Type not handled: {type(bsym.output)}") + + if trc.bound_symbols[-1].sym.name != 'return': + trc.bound_symbols.append(forced_return_bsym) + executor_configuration.append(Executor(name=self.empty_executor_hashable_placeholder)) + keys.append('return') + + if len(trc.bound_symbols) != len(executor_configuration) or len(keys) != len(executor_configuration): + raise AssertionError(f'len trc.bound_symbols ({len(trc.bound_symbols)}) != len executor_configuration ({len(executor_configuration)}) != len keys ({len(keys)})') + + placed_trace = self.place_optimizers(trc, executor_configuration) + return placed_trace, keys, executor_configuration + + best_ex_time = float('inf') + best_ex_trace = None + ex: FusionExecutor + for ex in self.fusion_executors: + # TODO (matteochen): fix + if not ex.name == 'nvfuser' and not ex.name == 'torchcompile': + raise AssertionError(f'Fusion operator not supported: {ex.name}') + + self.log(f'Searching best placement for fusion executor = {ex.name}') + # TODO (matteochen): each executor has a custo def + def _should_fuse_nvfuser(a: Node, b: Node): + def _can_fuse_node(n: Node): + # if already merged, then node can be fused + if len(n.group_bsyms) > 1: + return True + bsym: BoundSymbol = n.group_bsyms[0] + can_fuse: bool = ex.can_fuse(bsym) + cuda_in_or_out: bool = ex.has_cuda_input_or_output(bsym) + return can_fuse and cuda_in_or_out + return _can_fuse_node(a) and _can_fuse_node(b) + + def _should_fuse_torchcompile(a: Node, b: Node): + def _can_fuse_node(n: Node): + if len(n.group_bsyms) > 1: + return True + bsym: BoundSymbol = n.group_bsyms[0] + return ex.can_fuse(bsym) + return _can_fuse_node(a) and _can_fuse_node(b) + + bound_symbol_groups =fuse_bound_symbols(self.trace, _should_fuse_nvfuser if ex.name == 'nvfuser' else _should_fuse_torchcompile) + self.log(f'Num of groups = {len(bound_symbol_groups)}') + + for group in bound_symbol_groups: + for sub in group: + print(f'{sub.sym.name} -> out: {sub.output}') + if len(group) > 0: + print('\n') + + t_map: dict[str, Executor] = {} + increasing_symbols = [] + for i, group in enumerate(bound_symbol_groups): + print(f'group start = {group[0].sym.name}') + print(f'group end = {group[-1].sym.name}') + + # Is not a fusion region, get the default executor + increasing_symbols += group + if len(group) < 2: + symbol = group[0] + print(f'Single: {symbol.sym.name}') + name = symbol.sym.name + ex_for_this = get_default_executor(symbol) + if name == 'return': + t_map['return'] = ex_for_this + elif isinstance(symbol.output, Sequence): + t_map[sequence_hash(symbol.output)] = ex_for_this + elif isinstance(symbol.output, CollectionProxy) or isinstance(symbol.output, TensorProxy): + t_map[symbol.output.name] = ex_for_this + continue + + # Inside groups we should have alwasy tensors as out + # -> First iteration is the one with no fusion regions + # -> Last iteration gives the complete fusion region + best_trc = None + best_time = float('inf') + best_placement = None + best_keys = None + for i in range(len(group)): + # From top to bottom + for j in range(0, i+1, increment): + t_map[group[j].output.name] = ex + for k in range(i+1, len(group), increment): + t_map[group[k].output.name] = get_default_executor(group[k]) + + # Benchmark this placement + trc, keys, placements = get_placed_trace(t_map, increasing_symbols) + cost, out = benchmark_trace(trc, iters=2) + del out + self.log(f'Placed trace (cost = {cost / 1000000} ms)\n{trc}') + if cost < best_time: + best_time = cost + best_trc = trc + best_placement = placements + best_keys = keys + + # From bottom to up + for j in range(0, i+1, increment): + t_map[group[j].output.name] = get_default_executor(group[j]) + for k in range(i+1, len(group), increment): + t_map[group[k].output.name] = ex + + # Benchmark this placement + trc, keys, placements = get_placed_trace(t_map, increasing_symbols) + cost, out = benchmark_trace(trc, iters=2) + del out + self.log(f'Placed trace (cost = {cost / 1000000} ms)\n{trc}') + if cost < best_time: + best_time = cost + best_trc = trc + best_placement = placements + best_keys = keys + if best_placement is None or best_keys is None: + raise AssertionError('Failed to get best placement') + + self.log(f'For group {i} best placement with cost = {best_time / 1000000} ms is:\n{best_trc}') + + # for n, p in zip(best_keys, best_placement): + # print(f'{n} -> {p.name}') + + # Update our dict + for n, p in zip(best_keys, best_placement): + t_map |= {n: p} + + self.log('End of group search') + pprint.pprint(t_map) + + # Generate final trace + trc = from_trace(self.trace) + trc.bound_symbols = list(self.trace.bound_symbols) + executors = [] + for bsym in trc.bound_symbols: + if bsym.sym.name == 'return': + if 'return' not in t_map: + raise AssertionError(f'Expected key return in mapping {t_map}') + executors.append(t_map['return']) + elif isinstance(bsym.output, Sequence): + seq_hash = sequence_hash(bsym.output) + if seq_hash not in t_map: + raise AssertionError(f'Expected key {seq_hash} in mapping {t_map}') + executors.append(t_map[seq_hash]) + elif isinstance(bsym.output, CollectionProxy) or isinstance(bsym.output, TensorProxy): + if bsym.output.name not in t_map: + raise AssertionError(f'Expected key {bsym.output.name} in mapping {t_map}') + executors.append(t_map[bsym.output.name]) + else: + raise AssertionError(f"Type not handled: {type(bsym.output)}") + trc = self.place_optimizers(trc, executors) + + # Update res + cost, out = benchmark_trace(trc) + self.log(f'Final trace for ex {ex.name}, cost = {cost / 1000000} ms:\n{trc}') + del out + if cost < best_ex_time: + best_ex_time = cost + best_ex_trace = trc + self.log(f'Selected trace from fusion split optimizer:\n{best_ex_trace}') + def get_optimal_trace(self) -> TraceCtx: return self.optimal_trace From 24ad50958e6ce7a716e3dce8bb7fd0762b0c9e3a Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 19 Jul 2024 11:50:37 +0200 Subject: [PATCH 016/171] Fusion executor placements / runtime vs memory placements options --- examples/dev/LLaMAMLP.py | 124 +++-- thunder/__init__.py | 13 +- thunder/backend_optimizer/optimizer.py | 655 ++++++++++++++++--------- thunder/executors/passes.py | 8 +- thunder/executors/torch_autograd.py | 8 +- 5 files changed, 542 insertions(+), 266 deletions(-) diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 500bfb0bf4..545aa856ef 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -1,7 +1,5 @@ -import inspect import torch import thunder -import time class LLaMAMLP(torch.nn.Module): def __init__(self, n_embd, intermediate_size) -> None: @@ -16,53 +14,90 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) with torch.device('cuda'): + from thunder.backend_optimizer.optimizer import OptimizerType, benchmark_trace a = 4096 * 1 b = 11008 * 1 x = torch.randn(2, 2048, a, requires_grad=True) - jmodel_def = thunder.jit(LLaMAMLP(a, b), autotune_executors=False) - jmodel_auto = thunder.jit(LLaMAMLP(a, b), autotune_executors=True) + jmodel_def = thunder.jit(LLaMAMLP(a, b)) + jmodel_auto = thunder.jit(LLaMAMLP(a, b), autotune_type='memory') warm_up_iters = 2 iters = 10 + stream = torch.cuda.current_stream() - for i in range(iters): - start_fw = time.perf_counter_ns() + for _ in range(warm_up_iters): y = jmodel_auto(x) - torch.cuda.synchronize() - end_fw = time.perf_counter_ns() - grad_outputs = torch.ones_like(y) - torch.cuda.synchronize() - start_bw = time.perf_counter_ns() - torch.autograd.grad(y, x, grad_outputs=grad_outputs) - torch.cuda.synchronize() - end_bw = time.perf_counter_ns() - torch.cuda.empty_cache() - # source = inspect.getsource(y.grad_fn.compiled_backward) + yy = jmodel_def(x) + torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) + torch.autograd.grad(yy, x, grad_outputs=torch.ones_like(y)) + + print('\n\n') - if i >= warm_up_iters: - print(f'tot time auto forward = {(end_fw - start_fw) / 1000000} ms') - print(f'tot time auto backward = {(end_bw - start_bw) / 1000000} ms') + for i in range(1): - torch.cuda.synchronize() - torch.cuda.empty_cache() + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + y = jmodel_auto(x) + middle_events[i].record(stream) + torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) + end_events[i].record(stream) - for i in range(iters): - start_fw = time.perf_counter_ns() - y = jmodel_def(x) - torch.cuda.synchronize() - end_fw = time.perf_counter_ns() - grad_outputs = torch.ones_like(y) torch.cuda.synchronize() - start_bw = time.perf_counter_ns() - torch.autograd.grad(y, x, grad_outputs=grad_outputs) + fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] + bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + fw_time = sum(fw) + bw_time = sum(bw) + tot_time = sum(tot) + print(f'Auto fw: {fw_time / iters}') + print(f'Auto bw: {bw_time / iters}') + print(f'Auto tot: {tot_time / iters}') + print('\n') + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + y = jmodel_def(x) + middle_events[i].record(stream) + torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) + end_events[i].record(stream) + torch.cuda.synchronize() - end_bw = time.perf_counter_ns() - torch.cuda.empty_cache() - # source = inspect.getsource(y.grad_fn.compiled_backward) + fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] + bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + fw_time = sum(fw) + bw_time = sum(bw) + tot_time = sum(tot) + print(f'Default fw: {fw_time / iters}') + print(f'Default bw: {bw_time / iters}') + print(f'Default tot: {tot_time / iters}') + print('-------------------------------------------------------') + + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='def_fw') + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='auto_fw') + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='def_bw') + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='auto_bw') + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + del o - if i >= warm_up_iters: - print(f'tot time def forward = {(end_fw - start_fw) / 1000000} ms') - print(f'tot time def backward = {(end_bw - start_bw) / 1000000} ms') print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') print('###############################################################################') @@ -72,3 +107,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + + from torch.profiler import profile, record_function, ProfilerActivity + with profile(activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("def"): + y = jmodel_def(x) + grad_outputs = torch.ones_like(y) + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + with profile(activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("auto"): + y = jmodel_auto(x) + grad_outputs = torch.ones_like(y) + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) diff --git a/thunder/__init__.py b/thunder/__init__.py index c38482ddbc..5840aa1f23 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -285,7 +285,7 @@ def jit( additional_transforms: list[AdditionalTransform] | None = None, post_optimization_transforms: list[PostOptimizationTransform] | None = None, record_history: bool = False, - autotune_executors: bool = True, + autotune_type: Any | None = None, **compile_options, # TODO RC1 Make this explicit -- dict of options ) -> Callable: """Just-in-time compile a callable (function or model). @@ -308,6 +308,7 @@ def jit( transforms: List of transforms to be applied to the computation trace. It should be an instance :class:`thunder.core.transforms.AdditionalTransform`. Default: ``None`` post_optimization_transforms: List of transforms to be applied to the optimized computation traces i.e. forward and backward traces. It should be an instance :class:`thunder.core.transforms.PostOptimizationTransform`. Default: ``None`` """ + from thunder.backend_optimizer.optimizer import OptimizerType if "executors_list" in compile_options: warnings.warn("outdated argument executors_list= in call, please use executors=") @@ -333,6 +334,14 @@ def jit( if post_optimization_transforms is None: post_optimization_transforms = [] + if autotune_type is not None: + if autotune_type == 'runtime': + autotune_type = OptimizerType.RUNTIME + elif autotune_type == 'memory': + autotune_type = OptimizerType.MEMORY + else: + raise AssertionError(f'Not supported optimization: {autotune_type}') + # Resolve names of executors executors = resolve_executors(executors) @@ -583,7 +592,7 @@ def get_computation_and_inputs(*args, **kwargs): # transform_for_execution and various sorting of symbols, # applying transform_for_execution after this would be # breaking the order of operations - computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, autotune_executors, *inps) + computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, autotune_type, *inps) # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward extraces = cs.last_traces diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index d8aaaf1be8..a45dad7b3f 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,21 +1,25 @@ from collections.abc import Callable, Sequence from enum import Enum from itertools import chain +from thunder.core.prims import PrimIDs +from thunder.core.utils import check, safe_map_flat from thunder.core.baseutils import BoundSymbolInterface from thunder.core.proxies import CollectionProxy, Proxy, TensorProxy, variableify, Variable from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx -from thunder.core.utils import check, safe_map_flat from thunder.executors.data_dependent_partition import Graph, Node from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_all_executors, get_always_executors, resolve_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Any, Hashable import thunder import thunder.core.transforms as transforms -import time import torch # import concurrent.futures +class OptimizerType(Enum): + MEMORY = 1 + RUNTIME = 2 + class OptimizerNode(): def __init__(self, node: Node): self.node: Node = node @@ -28,33 +32,43 @@ class BackendOptimizer(): def log(self, what: str): print(f'================================================================================ Autotune: {what}') - def __init__(self, trace: TraceCtx, priority_executors: Sequence[Executor], produce_log=True, log_file_name='autotune_debug.log', visualizer: Visualizer | None = None) -> None: + def __init__(self, trace: TraceCtx, priority_executors: Sequence[Executor], produce_log=True, log_file_name='autotune_debug.log', visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME) -> None: + from thunder.core.transform_common import dce # Add more supported ones - self.executors: Sequence[Executor] = resolve_executors(['torch', 'python', 'nvfuser', 'torchcompile', 'sdpa', 'cudnn']) - self.priority_executors: Sequence[Executor] = priority_executors - self.trace: TraceCtx = trace - self.incremental_search_out_trace: TraceCtx - self.optimal_trace: TraceCtx = trace + self.trace: TraceCtx = dce(trace) + self.always_executors: tuple[Executor, ...] = get_always_executors() self.computation_graph: Graph = Graph(trace) - self.fusion_executors: Sequence[FusionExecutor] = [ex for ex in self.executors if isinstance(ex, FusionExecutor)] + self.debug_msg: str = "" self.empty_executor_hashable_placeholder: str = 'empty' - self.placement_options: list[list[Executor]] = [] - self.optimized_traces: list[dict[str, TraceCtx]] = [] - self.always_executors: tuple[Executor, ...] = get_always_executors() - self.produce_log: bool = produce_log + self.executors: Sequence[Executor] = resolve_executors(['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + self.fusion_executors: Sequence[FusionExecutor] = [ex for ex in self.executors if isinstance(ex, FusionExecutor)] + self.incremental_search_out_trace: TraceCtx self.log_file_name: str = log_file_name - self.debug_msg: str = "" - self.visualizer: Visualizer | None = visualizer + self.optimal_trace_mem: TraceCtx = trace + self.optimal_trace_time: TraceCtx = trace + self.optimized_traces_mem: list[dict[str | Hashable, TraceCtx]] = [] + self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] + self.optimized_traces_time: list[dict[str | Hashable, TraceCtx]] = [] + self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] self.partial_costs: dict[TraceCtx, float] = {} + self.placement_options_mem: list[list[Executor]] = [] + self.placement_options_time: list[list[Executor]] = [] + self.priority_executors: Sequence[Executor] = priority_executors + self.produce_log: bool = produce_log + self.strat = None + self.supported_fusion_executors_by_fusion_strat: set = set(['nvfuser', 'torchcompile']) + self.visualizer: Visualizer | None = visualizer + self.optimizer_type: OptimizerType = optimizer_type - self.log(f'New trace to optimize:\n{self.trace}') + self.log(f'New trace to optimize (strat = {self.optimizer_type}):\n{self.trace}') self.log('Executors:') for o in self.executors: - print(f'{o.name} -> {type(o)}, is operator = {isinstance(o, OperatorExecutor)}, is fusion = {isinstance(o, FusionExecutor)}') + self.log(f'{o.name} -> is operator = {isinstance(o, OperatorExecutor)}, is fusion = {isinstance(o, FusionExecutor)}') class OptimizationStrat(Enum): EXAUSTIVE = 1 GREEDY = 2 + BEST_FUSER = 3 # TODO (matteochen): fix this def __repr__(self) -> str: @@ -102,9 +116,6 @@ def safe_update_dict(d: dict, key_one, key_two, value): s = trace_in.bound_symbols[i] # For each bsym region we expect to output a Tensor tensors.append(s.output) - # print('Tensors inside partial trace') - # for t in tensors: - # print(t) forced_return_bsym = trace_in.bound_symbols[-1].from_bsym(args=tensors) # Should not be an Interface type at this point @@ -118,12 +129,12 @@ def safe_update_dict(d: dict, key_one, key_two, value): # Place the assigned symbols placed_t = self.place_optimizers(t, configuration) - cost, answer = benchmark_trace(placed_t, iters=5) + cost, mem, answer = benchmark_trace(placed_t, iters=5) del answer self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') self.log(f'Assigned executor = {configuration[-2].name}') - self.log(f'Time = {cost/1000000} ms') + self.log(f'Time = {cost} ms') # TODO (matteochen): log this to file safe_update_dict(self.partial_costs, whoami, t, cost) return cost, placed_t @@ -195,7 +206,7 @@ def continue_search(): def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: best_trace: TraceCtx = trace_in - best_time, answer = benchmark_trace(best_trace, iters=10) + best_time, best_mem, answer = benchmark_trace(best_trace, iters=10) del answer trace_in_time = best_time @@ -203,16 +214,16 @@ def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: self.log(f'Try to fuse executor {ex.name} with trace:\n{trace_in}') extrace = ex.fusion_pass(trace_in) self.log(f'Fused trace:\n{extrace}') - extrace_time, answer = benchmark_trace(extrace, iters=10) + extrace_time, extrace_mem, answer = benchmark_trace(extrace, iters=10) del answer - self.log(f'Fused trace time:{extrace_time/1000000} ms') + self.log(f'Fused trace time:{extrace_time} ms') if extrace_time < best_time: best_time = extrace_time best_trace = extrace - self.log(f'Trace in (time = {trace_in_time / 1000000} ms):\n{trace_in}') - self.log(f'Best fused trace (time = {best_time / 1000000} ms):\n{best_trace}') + self.log(f'Trace in (time = {trace_in_time } ms):\n{trace_in}') + self.log(f'Best fused trace (time = {best_time } ms):\n{best_trace}') return best_trace @@ -362,7 +373,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} if len(executor_list) != len(extrace.bound_symbols): - raise AssertionError("Invalid executor - bound_symbols lenght") + raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") for ex, bsym in zip(executor_list, extrace.bound_symbols): if isinstance(ex, FusionExecutor): @@ -399,14 +410,25 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return extrace - # TODO (matteochen): add config for exaustive search or incremental one - def optimize(self, strat: OptimizationStrat = OptimizationStrat.GREEDY): + def optimize(self, strat: OptimizationStrat = OptimizationStrat.BEST_FUSER): import thunder.core.codeutils as cutils + + self.strat = strat + from thunder.executors.passes import transform_for_execution + def best_fuser(): + self.build_placement_options_fusion_regions() + + if len(self.placement_options_time) != len(self.fusion_executors): + raise AssertionError("Unexpected time placement options size") + if len(self.placement_options_mem) != len(self.fusion_executors): + raise AssertionError("Unexpected mem placement options size") - self.split_fusion_executors_placement() - return + for placement, ex in zip(self.placement_options_time, self.fusion_executors): + self.optimized_traces_time.append({ex.name: self.place_optimizers(self.trace, placement)}) + for placement, ex in zip(self.placement_options_mem, self.fusion_executors): + self.optimized_traces_mem.append({ex.name: self.place_optimizers(self.trace, placement)}) def greedy(): # 1. This builds one option by default @@ -441,7 +463,6 @@ def exaustive(): for option in self.placement_options: option_str = [str(ex.name) for ex in option] option_str = '-'.join(option_str) - # print(f'============================================ optimizers len {len(option)}: {option_str}') trace = self.place_optimizers(self.trace, option) if self.visualizer is not None: @@ -455,22 +476,14 @@ def exaustive(): greedy() elif strat == self.OptimizationStrat.EXAUSTIVE: exaustive() + elif strat == self.OptimizationStrat.BEST_FUSER: + best_fuser() else: raise AssertionError('Optimization strat not implemented') - # For each node retrive subgraph - # 1 For each fusion executor compute the min common subgraph - # 2 Select the best runtime for the current subgraph - # 3 Mark nodes as visited - # 4 Nodes which can not be fused do select the best backend - - def split_fusion_executors_placement(self): - import pprint + def build_placement_options_fusion_regions(self, increment_factor:int = 1): from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols - # TODO: parametrize - increment = 1 - def sequence_hash(s: Sequence) -> str: name = "" for e in s: @@ -482,7 +495,7 @@ def sequence_hash(s: Sequence) -> str: raise AssertionError(f'What? type = {type(e)}') return name - # Benchmark the optimal executor and call this optimal + # TODO (matteochen): Benchmark the optimal executor and call this optimal def get_default_executor(bsym: BoundSymbol): for ex in self.executors: if isinstance(ex, FusionExecutor): @@ -491,73 +504,100 @@ def get_default_executor(bsym: BoundSymbol): return ex return Executor(name=self.empty_executor_hashable_placeholder) - def get_placed_trace(mapping: dict[str, Executor], bound_symbols: Sequence[BoundSymbol]): - import pprint + def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[BoundSymbol]): self.log(f'Input mapping len = {len(mapping)}:') - pprint.pprint(mapping) + self.log(f'Input bound_symbols len = {len(bound_symbols_in)}:') trc = from_trace(self.trace) - trc.bound_symbols = list(bound_symbols) + trc.bound_symbols = list(bound_symbols_in) # for b in trc.bound_symbols: # print(b.sym.name) # print(f'trc:\n{trc}') - # Check if this naming is always valid - def is_possible_out(name: str): - num = name[1:] - return num.isdigit() - - def find_original_return_tensors(trace_in: TraceCtx) -> list[str]: - return_bsym = trace_in.bound_symbols[-1] - if return_bsym.sym.name != 'return': - raise AssertionError(f'Expected return symbol got {return_bsym.sym.name}') - - ans = [] - # forward trace - if isinstance(return_bsym.args, tuple): - if isinstance(return_bsym.args[0], dict): - ans.append(return_bsym.args[0]['output']) - else: - ans.extend([s for s in return_bsym.args if s is not None]) - else: - raise AssertionError('Not supported') - - return ans - - def find_last_out_tensor(trace_in: TraceCtx): - m = 0 - t = None - for b in trace_in.bound_symbols: - if b.sym.name == 'return': - continue - if isinstance(b.output, TensorProxy): - if is_possible_out(b.output.name) and int(b.output.name[1:]) > m: - m = int(b.output.name[1:]) - t = b.output - # else: - # raise AssertionError(f'Not implemented, type = {type(b.output)}') - if t is None: - raise AssertionError('Max tensor output not found') - print(f'max tensor out name: {t}') - return t - - return_tensor = [find_last_out_tensor(trc)] - original_returns = find_original_return_tensors(self.trace) - for t in original_returns: - if t in trc.bound_symbols: - return_tensor.append(t) - - print(f'Return tensors: {return_tensor}') - - forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=return_tensor) # Should not be an Interface type at this point + # def find_original_return_tensors(trace_in: TraceCtx) -> list[Any]: + # return_bsym = trace_in.bound_symbols[-1] + # if return_bsym.sym.name != 'return': + # raise AssertionError(f'Expected return symbol got {return_bsym.sym.name}') + + # ans = [] + # if isinstance(return_bsym.args, tuple): + # # forward trace + # if isinstance(return_bsym.args[0], dict): + # ans.append(return_bsym.args[0]['output']) + # # backward trace + # else: + # ans.extend([s for s in return_bsym.args if s is not None]) + # else: + # raise AssertionError('Not supported') + + # return ans + + # def find_last_out_tensor(trace_in: TraceCtx): + # m = 0 + # t = None + # for b in trace_in.bound_symbols: + # if b.sym.name == 'return': + # continue + # if isinstance(b.output, TensorProxy): + # if is_possible_out(b.output.name) and int(b.output.name[1:]) > m: + # m = int(b.output.name[1:]) + # t = b.output + # # else: + # # raise AssertionError(f'Not implemented, type = {type(b.output)}') + # if t is None: + # raise AssertionError('Max tensor output not found') + # print(f'max tensor out name: {t}') + # return t + + # def is_tensor_in_bsyms(t: TensorProxy | tuple): + # def handle_tuple(tup: tuple): + # for e in tup: + # if isinstance(e, TensorProxy): + # for b in trc.bound_symbols: + # if b is not None: + # if isinstance(b.output, TensorProxy): + # if b.output.name == e.name: + # return b.output + # else: + # raise AssertionError('Not supported') + + # if isinstance(t, TensorProxy): + # for b in trc.bound_symbols: + # if b is not None: + # if isinstance(b.output, TensorProxy): + # if b.output.name == t.name: + # return b.output + # return None + # else: + # handle_tuple(t) + + + # tensors = [] + # for b in bound_symbols_in: + # if isinstance(b.output, TensorProxy): + # tensors.append(b.output) + # We include always the last tensor as output of the partial trace + all the already + # available out tensor present in the original trace in order to not be discarded from the dce + # tensors = [find_last_out_tensor(trc)] + # original_returns = find_original_return_tensors(self.trace) + # for t in original_returns: + # # TODO (matteochen): improve this + # res = is_tensor_in_bsyms(t) + # if res is not None: + # tensors.append(res) + + # For this partial trace we have to return all not used tensors otherwise the dce will cut them out + tensors = return_not_used(trc) + + forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=tensors) executor_configuration = [] empty_executor = Executor(name=self.empty_executor_hashable_placeholder) keys = [] for bsym in trc.bound_symbols: - print(f'current bsym {bsym.sym.name} -> type out {type(bsym.output)}') if bsym.sym.name == 'return': + raise AssertionError('return statement should not be here') executor_configuration.append(empty_executor) keys.append('return') elif isinstance(bsym.output, Sequence): @@ -580,15 +620,14 @@ def find_last_out_tensor(trace_in: TraceCtx): if len(trc.bound_symbols) != len(executor_configuration) or len(keys) != len(executor_configuration): raise AssertionError(f'len trc.bound_symbols ({len(trc.bound_symbols)}) != len executor_configuration ({len(executor_configuration)}) != len keys ({len(keys)})') + # self.log(f'Before placement trc:\n{trc}') placed_trace = self.place_optimizers(trc, executor_configuration) return placed_trace, keys, executor_configuration - best_ex_time = float('inf') - best_ex_trace = None ex: FusionExecutor for ex in self.fusion_executors: - # TODO (matteochen): fix - if not ex.name == 'nvfuser' and not ex.name == 'torchcompile': + + if ex.name not in self.supported_fusion_executors_by_fusion_strat: raise AssertionError(f'Fusion operator not supported: {ex.name}') self.log(f'Searching best placement for fusion executor = {ex.name}') @@ -617,132 +656,211 @@ def _can_fuse_node(n: Node): for group in bound_symbol_groups: for sub in group: - print(f'{sub.sym.name} -> out: {sub.output}') + self.log(f'{sub.sym.name} -> out: {sub.output}') if len(group) > 0: print('\n') - t_map: dict[str, Executor] = {} + map_time: dict[str, Executor] = {} + map_mem: dict[str, Executor] = {} increasing_symbols = [] - for i, group in enumerate(bound_symbol_groups): - print(f'group start = {group[0].sym.name}') - print(f'group end = {group[-1].sym.name}') + for group_id, group in enumerate(bound_symbol_groups): + self.log(f'group start = {group[0].sym.name}') + self.log(f'group end = {group[-1].sym.name}') + + if group[0].sym.name != 'return': + increasing_symbols += group # Is not a fusion region, get the default executor - increasing_symbols += group if len(group) < 2: symbol = group[0] - print(f'Single: {symbol.sym.name}') + self.log(f'--> Single group: {symbol.sym.name}') name = symbol.sym.name ex_for_this = get_default_executor(symbol) if name == 'return': - t_map['return'] = ex_for_this + map_time['return'] = ex_for_this + map_mem['return'] = ex_for_this + # Add the modified return statement at the end of the for loop + break elif isinstance(symbol.output, Sequence): - t_map[sequence_hash(symbol.output)] = ex_for_this + map_time[sequence_hash(symbol.output)] = ex_for_this + map_mem[sequence_hash(symbol.output)] = ex_for_this elif isinstance(symbol.output, CollectionProxy) or isinstance(symbol.output, TensorProxy): - t_map[symbol.output.name] = ex_for_this + map_time[symbol.output.name] = ex_for_this + map_mem[symbol.output.name] = ex_for_this continue # Inside groups we should have alwasy tensors as out - # -> First iteration is the one with no fusion regions - # -> Last iteration gives the complete fusion region - best_trc = None - best_time = float('inf') - best_placement = None - best_keys = None + best_res_time = self.Result() + best_res_mem = self.Result() + worst_res_time = self.Result() + worst_res_mem = self.Result() + worst_res_mem.measure = 0 + worst_res_time.measure = 0 + + best_placement_time = None + best_keys_time = None + best_placement_mem = None + best_keys_mem = None + # Each iteration of this loop will have map_time = map_mem, hence we use and fill only map_time + # Best time and best mem will be recorded separatedly though for i in range(len(group)): - # From top to bottom - for j in range(0, i+1, increment): - t_map[group[j].output.name] = ex - for k in range(i+1, len(group), increment): - t_map[group[k].output.name] = get_default_executor(group[k]) + # From top to bottom (this will include the whole region) + # -> First iteration is the one with fusion region with single element + # -> Last iteration gives the complete fusion region + for j in range(0, i+1, increment_factor): + map_time[group[j].output.name] = ex + map_mem[group[j].output.name] = ex + for k in range(i+1, len(group), increment_factor): + map_time[group[k].output.name] = get_default_executor(group[k]) + map_mem[group[k].output.name] = get_default_executor(group[k]) # Benchmark this placement - trc, keys, placements = get_placed_trace(t_map, increasing_symbols) - cost, out = benchmark_trace(trc, iters=2) + trc, keys, placements = get_placed_trace(map_time, increasing_symbols) + cost, mem, out = benchmark_trace(trc, iters=1) del out - self.log(f'Placed trace (cost = {cost / 1000000} ms)\n{trc}') - if cost < best_time: - best_time = cost - best_trc = trc - best_placement = placements - best_keys = keys - - # From bottom to up - for j in range(0, i+1, increment): - t_map[group[j].output.name] = get_default_executor(group[j]) - for k in range(i+1, len(group), increment): - t_map[group[k].output.name] = ex + self.log(f'Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}') + if cost < best_res_time.measure: + best_res_time.measure = cost + best_res_time.trace = trc + best_placement_time = placements + best_keys_time = keys + if cost > worst_res_time.measure: + worst_res_time.measure = cost + + if mem < best_res_mem.measure: + best_res_mem.measure = mem + best_res_mem.trace = trc + best_placement_mem = placements + best_keys_mem = keys + if mem > worst_res_mem.measure: + worst_res_mem.measure = mem + + # From bottom to up (this will exclude the full region as being handled in the for cycle above) + # -> First iteration is the one with len(fusion_region) - 1 + # -> Last iteration gives no fusion regions + # for j in range(0, i+1, increment_factor): + # map_time[group[j].output.name] = get_default_executor(group[j]) + # for k in range(i+1, len(group), increment_factor): + # map_time[group[k].output.name] = ex # Benchmark this placement - trc, keys, placements = get_placed_trace(t_map, increasing_symbols) - cost, out = benchmark_trace(trc, iters=2) - del out - self.log(f'Placed trace (cost = {cost / 1000000} ms)\n{trc}') - if cost < best_time: - best_time = cost - best_trc = trc - best_placement = placements - best_keys = keys - if best_placement is None or best_keys is None: + # trc, keys, placements = get_placed_trace(map_time, increasing_symbols) + # cost, out = benchmark_trace(trc, iters=2) + # del out + # self.log(f'Placed trace (cost = {cost } ms)\n{trc}') + # if cost < best_time: + # best_time = cost + # best_trc = trc + # best_placement = placements + # best_keys = keys + if best_placement_time is None or best_keys_time is None: + raise AssertionError('Failed to get best placement') + if best_placement_mem is None or best_keys_mem is None: raise AssertionError('Failed to get best placement') - self.log(f'For group {i} best placement with cost = {best_time / 1000000} ms is:\n{best_trc}') + self.log(f'For group {group_id} best placement with time cost = {best_res_time.measure} ms (worst time = {worst_res_time.measure} ms):\n{best_res_time.trace}') + self.log(f'For group {group_id} best placement with mem cost = {best_res_mem.measure / (2**30)} GB (worst mem = {worst_res_mem.measure/(2**30)} GB) is:\n{best_res_mem.trace}') # for n, p in zip(best_keys, best_placement): # print(f'{n} -> {p.name}') # Update our dict - for n, p in zip(best_keys, best_placement): - t_map |= {n: p} - - self.log('End of group search') - pprint.pprint(t_map) - - # Generate final trace - trc = from_trace(self.trace) - trc.bound_symbols = list(self.trace.bound_symbols) - executors = [] - for bsym in trc.bound_symbols: + for n, p in zip(best_keys_time, best_placement_time): + map_time |= {n: p} + # Update our dict + for n, p in zip(best_keys_mem, best_placement_mem): + map_mem |= {n: p} + + # self.log('End of group search') + # pprint.pprint(map_time) + + # print('map cmp') + # for k in map_time.keys(): + # if k not in map_mem: + # pprint.pprint(map_time) + # pprint.pprint(map_mem) + # raise AssertionError(f"cannot find {k}") + # pprint.pprint(map_time) + # pprint.pprint(map_mem) + + # Generate the placement + executors_time = [] + executors_mem = [] + for bsym in self.trace.bound_symbols: if bsym.sym.name == 'return': - if 'return' not in t_map: - raise AssertionError(f'Expected key return in mapping {t_map}') - executors.append(t_map['return']) + if 'return' not in map_time or 'return' not in map_mem: + raise AssertionError(f'Expected key return in mapping {map_time} and {map_mem}') + executors_time.append(map_time['return']) + executors_mem.append(map_mem['return']) elif isinstance(bsym.output, Sequence): seq_hash = sequence_hash(bsym.output) - if seq_hash not in t_map: - raise AssertionError(f'Expected key {seq_hash} in mapping {t_map}') - executors.append(t_map[seq_hash]) + if seq_hash not in map_time or seq_hash not in map_mem: + raise AssertionError(f'Expected key {seq_hash} in mapping {map_time} and {map_mem}') + executors_time.append(map_time[seq_hash]) + executors_mem.append(map_mem[seq_hash]) elif isinstance(bsym.output, CollectionProxy) or isinstance(bsym.output, TensorProxy): - if bsym.output.name not in t_map: - raise AssertionError(f'Expected key {bsym.output.name} in mapping {t_map}') - executors.append(t_map[bsym.output.name]) + if bsym.output.name not in map_time or bsym.output.name not in map_mem: + raise AssertionError(f'Expected key {bsym.output.name} in mapping {map_time} and {map_mem}') + executors_time.append(map_time[bsym.output.name]) + executors_mem.append(map_mem[bsym.output.name]) else: raise AssertionError(f"Type not handled: {type(bsym.output)}") - trc = self.place_optimizers(trc, executors) - # Update res - cost, out = benchmark_trace(trc) - self.log(f'Final trace for ex {ex.name}, cost = {cost / 1000000} ms:\n{trc}') - del out - if cost < best_ex_time: - best_ex_time = cost - best_ex_trace = trc - self.log(f'Selected trace from fusion split optimizer:\n{best_ex_trace}') + # Swap return bsym otherwise with no call to remat, we will trace the wrong memory occupation + test_trc = from_trace(self.trace) + test_trc.bound_symbols = list(self.trace.bound_symbols) + test_trc.bound_symbols.pop() + test_trc.bound_symbols.append(self.trace.bound_symbols[-1].from_bsym(args=return_not_used(test_trc))) + trc = self.place_optimizers(test_trc, executors_mem) + c, m, o = benchmark_trace(trc) + del o + self.log(f'Debug MEM, mem = {m/(2**30)} GB:\n{trc}') + self.optimized_traces_mem_benchmark_only.append({ex.name: trc}) + trc = self.place_optimizers(test_trc, executors_time) + c, m, o = benchmark_trace(trc) + del o + self.log(f'Debug TIME, time = {c} ms:\n{trc}') + self.optimized_traces_time_benchmark_only.append({ex.name: trc}) + + # Save executors in order to generate real fw and bw trace with correct output + self.placement_options_time.append(executors_time) + self.placement_options_mem.append(executors_mem) def get_optimal_trace(self) -> TraceCtx: - return self.optimal_trace + if self.optimizer_type == OptimizerType.RUNTIME: + return self.optimal_trace_time + else: + return self.optimal_trace_mem def bsym_assigned(self, bsym: BoundSymbol) -> bool: return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) + class Result: + def __init__(self) -> None: + self.measure: float = float('inf') + self.trace: TraceCtx | None = None + self.label: str | Hashable = "" + self.index = -1 + def benchmark_traces(self): - min_run_time = float('inf') - optimal_trace: TraceCtx = self.trace # Assign initial value for unbound errors - best_label = "" + + tm = self.Result() + mem = self.Result() self.debug_msg += 'Traces benchmarks:\n\n' - for trace_info in self.optimized_traces: + source_mem = None + source_time = None + if self.strat == self.OptimizationStrat.BEST_FUSER: + source_mem = self.optimized_traces_mem_benchmark_only + source_time = self.optimized_traces_time_benchmark_only + elif self.strat == self.OptimizationStrat.GREEDY: + source_mem = self.optimized_traces_mem + source_time = self.optimized_traces_time + else: + raise AssertionError('Not supported') + + for i, trace_info in enumerate(source_time): label = None trace = None @@ -750,45 +868,143 @@ def benchmark_traces(self): label = k trace = v - trace_time, res = benchmark_trace(trace, iters=10) + trace_time, _, res = benchmark_trace(trace, iters=10) del res - self.debug_msg += f'Trace name = [{label}] - Time = {trace_time / 1000000} ms\n{trace}\n\n' - self.log(f'Benchmark trace "{label}" (time = {trace_time / 1000000} ms):\n{trace}') - if trace_time < min_run_time: - min_run_time = trace_time - optimal_trace = trace - best_label = label + self.debug_msg += f'Trace name = [{label}] - Time = {trace_time} ms\n{trace}\n\n' + self.log(f'Benchmark trace "{label}" (time = {trace_time} ms:\n{trace}') + if trace_time < tm.measure: + tm.measure = trace_time + tm.trace = trace + tm.label = label + tm.index = i - self.log(f'Benchmark end: Best trace "{best_label} (time = {min_run_time / 1000000} ms)":\n{optimal_trace}') + for i, trace_info in enumerate(source_mem): - self.optimal_trace = optimal_trace + label = None + trace = None + for k, v in trace_info.items(): + label = k + trace = v + + _, trace_mem, res = benchmark_trace(trace, iters=10) + del res + self.debug_msg += f'Trace name = [{label}] - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n' + self.log(f'Benchmark trace "{label}" (mem = {trace_mem / (2 ** 30)} GB):\n{trace}') + if trace_mem < mem.measure: + mem.measure = trace_mem + mem.trace = trace + mem.label = label + mem.index = i + + self.log(f'Benchmark end: Best trace time "{tm.label} (time = {tm.measure} ms)":\n{tm.trace}') + self.log(f'Benchmark end: Best trace mem "{mem.label} (mem = {mem.measure / (2 ** 30)} GB)":\n{mem.trace}') + + self.log('Strat comparison') + c, m, o = benchmark_trace(tm.trace) + del o + self.log(f'best time: {c} ms, {m/(2**30)} GB') + c, m, o = benchmark_trace(mem.trace) + del o + self.log(f'best mem: {c} ms, {m/(2**30)} GB') + + # TODO (matteochen): use time or mem strat + if self.strat == self.OptimizationStrat.GREEDY: + self.optimal_trace_time = tm.trace + self.optimal_trace_mem = mem.trace + elif self.strat == self.OptimizationStrat.BEST_FUSER: + d = self.optimized_traces_time[tm.index] + t = None + for _, v in d.items(): + t = v + self.optimal_trace_time = t + d = self.optimized_traces_mem[mem.index] + t = None + for _, v in d.items(): + t = v + self.optimal_trace_mem = t + + self.log(f'Saved best trace time:\n{self.optimal_trace_time}') + self.log(f'Saved best trace mem:\n{self.optimal_trace_mem}') if self.produce_log: with open(self.log_file_name, 'w') as file: file.write(self.debug_msg) file.close() +def return_not_used(trace_in: TraceCtx) -> list[TensorProxy]: + def is_in_sequence(seq: Sequence[Any], t:TensorProxy): + for e in seq: + if isinstance(e, TensorProxy) and e.name == t.name: + return True + return False + + # Check if this naming is always valid + def is_possible_out(name: str): + if not name.startswith('t'): + return False + num = name[1:] + return num.isdigit() + + ans: list[TensorProxy] = [] + for b in trace_in.bound_symbols: + f = False + # Not a tensor + if not isinstance(b.output, TensorProxy): + continue + # Not a produced tensor + if not is_possible_out(b.output.name): + continue + for test in trace_in.bound_symbols: + if test.args is not None and (isinstance(test.args, tuple) or isinstance(test.args, list)) and is_in_sequence(test.args, b.output): + f = True + break + if not f: + ans.append(b.output) + return ans + # This will benchmark the input trace with the del_last_used call -def benchmark_trace(trace: TraceCtx, iters: int = 1) -> tuple[float, Any]: +def benchmark_trace(trace: TraceCtx, iters: int = 1, show_func = False, apply_del_last_used = True, snapshot = False, snapshot_name = "") -> tuple[float, float, Any]: from thunder.executors.passes import del_last_used + import inspect input_args = [] - def compute_time_cost(fn: Callable, iters: int, *args) -> tuple[float, Any]: + if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: + raise AssertionError('Missing return statement') + + def compute_time_cost_ms(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: warm_up_iters = 3 - total_time = 0 out = None torch.cuda.empty_cache() - for i in range(iters + warm_up_iters): - time_s = time.perf_counter_ns() - out = fn(*args) - torch.cuda.synchronize() - time_e = time.perf_counter_ns() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + max_allocated_bytes = 0 + # Warm up cycles + for _ in range(warm_up_iters): + fn(*args) + # Snapshot request + if snapshot: + torch.cuda.memory._record_memory_history() + fn(*args) + torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") + torch.cuda.memory._record_memory_history(enabled=None) + # Benchmark + stream = torch.cuda.current_stream() + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() - if i >= warm_up_iters: - total_time += (time_e - time_s) + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + fn(*args) + end_events[i].record(stream) + max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) - return total_time / iters, out + torch.cuda.synchronize() + times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(times) / iters + return tot_time, max_allocated_bytes, out def print_input_args(args, level=0, show_content = False): for e in args: @@ -797,12 +1013,12 @@ def print_input_args(args, level=0, show_content = False): else: print(f'level {level}', type(e)) - def print_trace_execution_output(out: Any, show_content=False): - if isinstance(out, tuple): - for e in out: - print(f'{type(e)}') - else: - print(f'{type(out)}') + # def print_trace_execution_output(out: Any, show_content=False): + # if isinstance(out, tuple): + # for e in out: + # print(f'{type(e)}') + # else: + # print(f'{type(out)}') def thunder_to_torch_float_dtype(byte: int) -> torch.dtype: if (byte == 2): @@ -818,7 +1034,6 @@ def transform_input_tuple(t: tuple, level=0) -> tuple: if type(e) is tuple: res.append(transform_input_tuple(e, level+1)) else: - # print(f'level {level}', type(e)) if isinstance(e, TensorProxy): res.append(transform_tensor(e)) else: @@ -830,7 +1045,6 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: dtype = arg.dtype if dtype is not None and type(dtype) is thunder.dtypes.floating: torch_dtype = thunder_to_torch_float_dtype(dtype.bytes) - # print(f'thunder type: {dtype} torch_dtype: {torch_dtype}') else: # TODO (matteochen): support other types raise AssertionError(f"dtype {dtype} not supported yet") @@ -841,18 +1055,14 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: # TODO (matteochen): Missing parallel and fsdp handling... # TODO (matteochen): Missing support for meta types ... tensor: torch.Tensor = torch.randn(*shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad) - # print(f'Adding tensor shape: {tensor.shape} dtype: {tensor.dtype} device: {tensor.device} requires_grad: {tensor.requires_grad}') return tensor # Can we remove this check? - if isinstance(trace.args, list): + if isinstance(trace.args, Sequence): for arg in trace.args: - # print(f'current arg {arg}\ntype {type(arg)}') if isinstance(arg, tuple): - # print('Processig tuple') input_args.append(transform_input_tuple(arg)) elif isinstance(arg, TensorProxy): - # print('Processig TensorProxy') e = transform_tensor(arg) input_args.append(e) else: @@ -860,17 +1070,18 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: else: raise AssertionError('Unexpexcted args type') - # Always benchmark trace after a deletion last used pass - trace = del_last_used(trace) + # Always benchmark trace after a deletion last used pass as the final trace out will passed under this stage + if apply_del_last_used: + trace = del_last_used(trace) - # TODO (matteochen): measure time trace_tok = set_tracectx(trace) # Obtain the python executable string - executable_str = trace.python_callable() - # TODO (matteochen): make the iters configurable - t, answer = compute_time_cost(executable_str, iters, *input_args) + executable = trace.python_callable() + if show_func: + print(inspect.getsource(executable)) + t, m, answer = compute_time_cost_ms(executable, iters, *input_args) reset_tracectx(trace_tok) - return t, answer + return t, m, answer diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index b37298dd8f..65d454c2e6 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -23,7 +23,7 @@ from thunder.executors.pythonex import clear_mutable_collection from thunder.extend import Executor, get_all_executors, get_always_executors, OperatorExecutor, FusionExecutor -from thunder.backend_optimizer.optimizer import BackendOptimizer +from thunder.backend_optimizer.optimizer import BackendOptimizer, OptimizerType from thunder.visualizer.graphviz import create_graphviz_pdf from thunder.visualizer.visualizer_helper import Visualizer @@ -136,7 +136,7 @@ def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: return extrace # Autotuned transform_for_execution version -def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], visualizer: Visualizer | None = None) -> TraceCtx: +def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], autotune_type: OptimizerType, visualizer: Visualizer | None = None) -> TraceCtx: import torch # Recover the function name @@ -152,7 +152,7 @@ def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[E trace = dce(trace) - backend_optimizer = BackendOptimizer(trace, executors_list, produce_log=True, log_file_name=f'autotune_transform_for_execution_{sig_name}.log', visualizer=visualizer) + backend_optimizer = BackendOptimizer(trace, executors_list, produce_log=True, log_file_name=f'autotune_transform_for_execution_{sig_name}.log', visualizer=visualizer, optimizer_type=autotune_type) backend_optimizer.optimize() backend_optimizer.benchmark_traces() extrace = backend_optimizer.get_optimal_trace() @@ -161,7 +161,7 @@ def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[E elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - extrace.set_provenance(TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)")) + extrace.set_provenance(TraceProvenance(f"Autotuned transform for execution (strat: {autotune_type}) (took {elapsed_time_millis} milliseconds)")) return extrace diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 0b2c393982..13850e56de 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -106,7 +106,7 @@ def backward(ctx, *args): return (None, None, None, None, None, *([None] * n_grads)) # TODO (matteochen): add control for using autotuner or not -def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, autotune_executors, /, *flat_args): +def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, autotune_type, /, *flat_args): from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace from thunder.distributed.transforms import FSDPCommBucketing @@ -170,10 +170,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # TODO Restore request for no rematerialization visualizer.set_fw_initial_trace(fw_trace) - if autotune_executors: + if autotune_type is not None: fw_extrace = autotune_transform_for_execution( fw_trace, executors_list=compile_data.executors_list, + autotune_type=autotune_type, visualizer=visualizer ) else: @@ -216,10 +217,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization visualizer.set_bw_initial_trace(bw_trace) - if autotune_executors: + if autotune_type is not None: bw_extrace = autotune_transform_for_execution( bw_trace, executors_list=compile_data.executors_list, + autotune_type=autotune_type, visualizer=visualizer ) else: From d94e841d91fc566608e5482d91fefa5328db42e5 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 19 Jul 2024 12:41:56 +0200 Subject: [PATCH 017/171] Updated test model --- examples/dev/LLaMAMLP.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 545aa856ef..58382b8d0a 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -14,13 +14,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) with torch.device('cuda'): - from thunder.backend_optimizer.optimizer import OptimizerType, benchmark_trace + from thunder.backend_optimizer.optimizer import benchmark_trace a = 4096 * 1 b = 11008 * 1 x = torch.randn(2, 2048, a, requires_grad=True) + model = LLaMAMLP(a, b) + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='memory') - jmodel_def = thunder.jit(LLaMAMLP(a, b)) - jmodel_auto = thunder.jit(LLaMAMLP(a, b), autotune_type='memory') warm_up_iters = 2 iters = 10 stream = torch.cuda.current_stream() @@ -28,9 +29,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: for _ in range(warm_up_iters): y = jmodel_auto(x) yy = jmodel_def(x) + yyy = model(x) torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) torch.autograd.grad(yy, x, grad_outputs=torch.ones_like(y)) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + print('\n\n') for i in range(1): From 15f51bfa8a503e69e00cdb8f0f767d8775fa19fb Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 26 Jul 2024 15:50:13 +0300 Subject: [PATCH 018/171] Enhanced bw trace placement seach / support for more models (#3) --- examples/dev/LLaMAMLP.py | 120 +- examples/dev/MLP.py | 91 ++ examples/dev/csa.py | 153 +++ examples/dev/litGPT.py | 45 +- examples/dev/simple.py | 137 ++- thunder/backend_optimizer/optimizer.py | 1420 ++++++++++++++++-------- thunder/executors/passes.py | 47 +- thunder/executors/torch_autograd.py | 111 +- 8 files changed, 1445 insertions(+), 679 deletions(-) create mode 100644 examples/dev/MLP.py create mode 100644 examples/dev/csa.py diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 58382b8d0a..98fadb8dbd 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -15,93 +15,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.device('cuda'): from thunder.backend_optimizer.optimizer import benchmark_trace - a = 4096 * 1 - b = 11008 * 1 + # See changes from mult = 1 to mult = 4 + mult = 1 + a = 4096 * mult + b = 11008 * mult x = torch.randn(2, 2048, a, requires_grad=True) model = LLaMAMLP(a, b) - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='memory') - - warm_up_iters = 2 - iters = 10 - stream = torch.cuda.current_stream() - - for _ in range(warm_up_iters): - y = jmodel_auto(x) - yy = jmodel_def(x) - yyy = model(x) - torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) - torch.autograd.grad(yy, x, grad_outputs=torch.ones_like(y)) + jmodel_def = thunder.jit(model, executors=['torchcompile', 'nvfuser']) + jmodel_auto = thunder.jit(model, autotune_type='runtime') + y = model(x) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('\n\n') - - for i in range(1): - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - y = jmodel_auto(x) - middle_events[i].record(stream) - torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) - end_events[i].record(stream) - - torch.cuda.synchronize() - fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] - bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - fw_time = sum(fw) - bw_time = sum(bw) - tot_time = sum(tot) - print(f'Auto fw: {fw_time / iters}') - print(f'Auto bw: {bw_time / iters}') - print(f'Auto tot: {tot_time / iters}') - print('\n') - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - y = jmodel_def(x) - middle_events[i].record(stream) - torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) - end_events[i].record(stream) - - torch.cuda.synchronize() - fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] - bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - fw_time = sum(fw) - bw_time = sum(bw) - tot_time = sum(tot) - print(f'Default fw: {fw_time / iters}') - print(f'Default bw: {bw_time / iters}') - print(f'Default tot: {tot_time / iters}') - print('-------------------------------------------------------') - - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='def_fw') - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='auto_fw') - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='def_bw') - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='auto_bw') - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - del o + print('########################################') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_fw') + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_fw') + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_bw') + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_bw') + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + del o print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') @@ -112,22 +51,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - - from torch.profiler import profile, record_function, ProfilerActivity - with profile(activities=[ - ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("def"): - y = jmodel_def(x) - grad_outputs = torch.ones_like(y) - torch.autograd.grad(y, x, grad_outputs=grad_outputs) - - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - with profile(activities=[ - ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("auto"): - y = jmodel_auto(x) - grad_outputs = torch.ones_like(y) - torch.autograd.grad(y, x, grad_outputs=grad_outputs) - - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py new file mode 100644 index 0000000000..cd803d76c9 --- /dev/null +++ b/examples/dev/MLP.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +import thunder +from thunder.backend_optimizer.optimizer import benchmark_trace +# import logging + +# torch._logging.set_logs(dynamo = logging.DEBUG) +# torch._dynamo.config.verbose = True + +class ModelConfig: + def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): + self.n_embd = n_embd + self.n_head = n_head + self.dropout = dropout + self.bias = bias + self.block_size = block_size + +class MLP(nn.Module): + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + +with torch.device('cuda'): + embeddings = 3072 + config = ModelConfig(n_embd=embeddings) + dtype = torch.float32 + x = torch.randn(16, 1024, embeddings, requires_grad=True) + + model = MLP(config) + + jmodel_def = thunder.jit(model) + # This model fails under some circumstances after passed the placed traced under the rematelizer + jmodel_auto = thunder.jit(model, autotune_type='memory') + + y = model(x) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + + print('########################################') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_fw') + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_fw') + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_bw') + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_bw') + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + del o + + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') + + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + + # from torch.profiler import profile, record_function, ProfilerActivity + # with profile(activities=[ + # ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + # with record_function("def"): + # y = jmodel_def(x) + # grad_outputs = torch.ones_like(y) + # torch.autograd.grad(y, x, grad_outputs=grad_outputs) + + # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + + # with profile(activities=[ + # ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + # with record_function("auto"): + # y = jmodel_auto(x) + # grad_outputs = torch.ones_like(y) + # torch.autograd.grad(y, x, grad_outputs=grad_outputs) + + # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + diff --git a/examples/dev/csa.py b/examples/dev/csa.py new file mode 100644 index 0000000000..c011e5d48b --- /dev/null +++ b/examples/dev/csa.py @@ -0,0 +1,153 @@ +import torch +import torch.nn as nn +import thunder +from thunder.backend_optimizer.optimizer import benchmark_trace + +# import torch._dynamo +# torch._dynamo.config.suppress_errors = True + +class CausalSelfAttention(nn.Module): + + def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0): + super().__init__() + assert embed_dimension % num_heads == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias) + # output projection + self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias) + # regularization + self.dropout = dropout + self.resid_dropout = nn.Dropout(dropout) + self.num_heads = num_heads + self.embed_dimension = embed_dimension + # Perform causal masking + self.is_causal = is_causal + + def forward(self, x): + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query_projected = self.c_attn(x) + + batch_size = query_projected.size(0) + embed_dim = query_projected.size(2) + head_dim = embed_dim // (self.num_heads * 3) + + query, key, value = query_projected.chunk(3, -1) + query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) + + if self.training: + dropout = self.dropout + is_causal = self.is_causal + else: + dropout = 0.0 + is_causal = False + + y = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal) + y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim) + + y = self.resid_dropout(self.c_proj(y)) + return y + +device = torch.device('cuda') +num_heads = 8 +heads_per_dim = 64 * 4 +embed_dimension = num_heads * heads_per_dim +dtype = torch.float32 +model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(device).to(dtype) +print(model) +batch_size = 16 +max_sequence_len = 1024 +x = torch.randn(batch_size, max_sequence_len, embed_dimension, dtype=dtype, requires_grad=True, device=device) + +jmodel_def = thunder.jit(model) +jmodel_auto = thunder.jit(model, autotune_type='runtime') + +warm_up_iters = 2 +iters = 10 +stream = torch.cuda.current_stream() + +y = model(x) +for _ in range(warm_up_iters): + yy = jmodel_def(x) + yyy = jmodel_auto(x) + torch.autograd.grad(yy, x, grad_outputs=torch.ones_like(y)) + torch.autograd.grad(yyy, x, grad_outputs=torch.ones_like(y)) + +print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) +print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + +# print('\n\n') + +# start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +# middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +# end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + +# for i in range(iters): +# torch.cuda.empty_cache() +# torch.cuda._sleep(1_000_000) +# start_events[i].record(stream) +# y = jmodel_auto(x) +# middle_events[i].record(stream) +# torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) +# end_events[i].record(stream) + +# torch.cuda.synchronize() +# fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] +# bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] +# tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] +# fw_time = sum(fw) +# bw_time = sum(bw) +# tot_time = sum(tot) +# print(f'Auto fw: {fw_time / iters}') +# print(f'Auto bw: {bw_time / iters}') +# print(f'Auto tot: {tot_time / iters}') +# print('\n') + +# start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +# middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +# end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + +# for i in range(iters): +# torch.cuda.empty_cache() +# torch.cuda._sleep(1_000_000) +# start_events[i].record(stream) +# y = jmodel_def(x) +# middle_events[i].record(stream) +# torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) +# end_events[i].record(stream) + +# torch.cuda.synchronize() +# fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] +# bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] +# tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] +# fw_time = sum(fw) +# bw_time = sum(bw) +# tot_time = sum(tot) +# print(f'Default fw: {fw_time / iters}') +# print(f'Default bw: {bw_time / iters}') +# print(f'Default tot: {tot_time / iters}') +# print('-------------------------------------------------------') + +c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], iters = 10, apply_del_last_used=False, snapshot=True, snapshot_name='def_fw') +print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') +del o +c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], iters=10, apply_del_last_used=False, snapshot=True, snapshot_name='auto_fw') +print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') +del o +c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], iters=10, apply_del_last_used=False, snapshot=True, snapshot_name='def_bw') +print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') +del o +c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], iters=10, apply_del_last_used=False, snapshot=True, snapshot_name='auto_bw') +print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') +del o + +print('\n\n\n\n\n\n') +print(f'{thunder.last_traces(jmodel_def)[-1]}') +print('###############################################################################') +print(f'{thunder.last_traces(jmodel_auto)[-1]}') + +print('\n\n') +print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') +print('###############################################################################') +print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 7139a415c0..ec8c78044d 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -2,15 +2,48 @@ from thunder.tests.litgpt_model import Config import thunder import torch +from thunder.backend_optimizer.optimizer import benchmark_trace cfg = Config.from_name('Llama-2-7b-hf') -cfg.n_layer = 16 # fewer layers +cfg.n_layer = 1 # fewer layers torch.set_default_dtype(torch.bfloat16) + with torch.device('cuda'): - m = GPT(cfg) - thunder_model = thunder.jit(m) + model = GPT(cfg) + x = torch.randint(1, model.config.vocab_size, (1, 512)) + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='runtime') + y = jmodel_def(x) + yy = jmodel_auto(x) + + jmodel_def = thunder.jit(model) + # This model fails under some circumstances after passed the placed traced under the rematelizer + jmodel_auto = thunder.jit(model, autotune_type='memory') + + y = model(x) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + + print('########################################') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_fw') + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_fw') + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_bw') + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_bw') + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + del o - inp = torch.randint(1, m.config.vocab_size, (1, 512)) + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') - actual = thunder_model(inp) - expected = m(inp) + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') diff --git a/examples/dev/simple.py b/examples/dev/simple.py index e0c7d65e11..a3128ffda2 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -1,7 +1,6 @@ import torch import thunder -import time -import inspect +from thunder.backend_optimizer.optimizer import benchmark_trace class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: @@ -20,45 +19,113 @@ def forward(self, x: torch.Tensor): multiplier = 1000 in_features = 20 * multiplier out_features = 30 * multiplier - - jmodel_default = thunder.jit(Module(in_features, out_features), autotune_executors=False) - jmodel_autotune = thunder.jit(Module(in_features, out_features), autotune_executors=True) + model = Module(in_features, out_features) x = torch.randn(128, in_features, requires_grad=True) - warm_up_iters = 3 - for i in range(10): - start_fw = time.perf_counter_ns() - y = jmodel_default(x) - torch.cuda.synchronize() - end_fw = time.perf_counter_ns() + + jmodel_def = thunder.jit(model, autotune_executors=False) + jmodel_auto = thunder.jit(model, autotune_executors=True) + stream = torch.cuda.current_stream() + + warm_up_iters = 2 + iters = 10 + for _ in range(warm_up_iters): + y = jmodel_auto(x) + yy = jmodel_def(x) grad_outputs = torch.ones_like(y) - torch.cuda.synchronize() - start_bw = time.perf_counter_ns() torch.autograd.grad(y, x, grad_outputs=grad_outputs) - torch.cuda.synchronize() - end_bw = time.perf_counter_ns() - torch.cuda.empty_cache() - # source = inspect.getsource(y.grad_fn.compiled_backward) + torch.autograd.grad(yy, x, grad_outputs=grad_outputs) - if i >= warm_up_iters: - print(f'tot time default forward = {(end_fw - start_fw) / 1000000} ms') - print(f'tot time default backward = {(end_bw - start_bw) / 1000000} ms') + print('\n\n') + + for i in range(1): + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + y = jmodel_auto(x) + middle_events[i].record(stream) + torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) + end_events[i].record(stream) - for i in range(10): - start_fw = time.perf_counter_ns() - y = jmodel_autotune(x) - torch.cuda.synchronize() - end_fw = time.perf_counter_ns() - grad_outputs = torch.ones_like(y) torch.cuda.synchronize() - start_bw = time.perf_counter_ns() - torch.autograd.grad(y, x, grad_outputs=grad_outputs) + fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] + bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + fw_time = sum(fw) + bw_time = sum(bw) + tot_time = sum(tot) + print(f'Auto fw: {fw_time / iters}') + print(f'Auto bw: {bw_time / iters}') + print(f'Auto tot: {tot_time / iters}') + print('\n') + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + y = jmodel_def(x) + middle_events[i].record(stream) + torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) + end_events[i].record(stream) + torch.cuda.synchronize() - end_bw = time.perf_counter_ns() - torch.cuda.empty_cache() - # source = inspect.getsource(y.grad_fn.compiled_backward) + fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] + bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + fw_time = sum(fw) + bw_time = sum(bw) + tot_time = sum(tot) + print(f'Default fw: {fw_time / iters}') + print(f'Default bw: {bw_time / iters}') + print(f'Default tot: {tot_time / iters}') + print('-------------------------------------------------------') + + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False) + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False) + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False) + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False) + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + del o + # print('\n\n\n\n\n\n') + # print(f'{thunder.last_traces(jmodel_def)[-1]}') + # print('###############################################################################') + # print(f'{thunder.last_traces(jmodel_auto)[-1]}') + + # print('\n\n') + # print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + # print('###############################################################################') + # print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + + from torch.profiler import profile, record_function, ProfilerActivity + with profile(activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("def"): + y = jmodel_def(x) + grad_outputs = torch.ones_like(y) + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - if i >= warm_up_iters: - print(f'tot time autotune forward = {(end_fw - start_fw) / 1000000} ms') - print(f'tot time autotune backward = {(end_bw - start_bw) / 1000000} ms') - # print('\n\n', thunder.last_backward_traces(jmodel)[-1]) + with profile(activities=[ + ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: + with record_function("auto"): + y = jmodel_auto(x) + grad_outputs = torch.ones_like(y) + torch.autograd.grad(y, x, grad_outputs=grad_outputs) + print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index a45dad7b3f..ff6d6d210a 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,26 +1,40 @@ from collections.abc import Callable, Sequence from enum import Enum from itertools import chain +from thunder.core.dtypes import dtype, is_boolean_dtype from thunder.core.prims import PrimIDs from thunder.core.utils import check, safe_map_flat from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.proxies import CollectionProxy, Proxy, TensorProxy, variableify, Variable +from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, variableify, Variable from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx from thunder.executors.data_dependent_partition import Graph, Node -from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_all_executors, get_always_executors, resolve_executors +from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors, resolve_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Any, Hashable import thunder import thunder.core.transforms as transforms import torch -# import concurrent.futures +import time + class OptimizerType(Enum): - MEMORY = 1 - RUNTIME = 2 + MEMORY = 0 + RUNTIME = 1 + + +class TraceType(Enum): + FW = 0 + BW = 1 + + +class OptimizationAlgorithm(Enum): + EXAUSTIVE = 0 + GREEDY = 1 + BEST_FUSER = 2 -class OptimizerNode(): + +class OptimizerNode: def __init__(self, node: Node): self.node: Node = node self.candidate_executors: dict[Hashable, float] = {} @@ -28,122 +42,214 @@ def __init__(self, node: Node): def add_candidate(self, ex: Executor, benchmark: float): self.candidate_executors[ex] = benchmark -class BackendOptimizer(): - def log(self, what: str): - print(f'================================================================================ Autotune: {what}') - def __init__(self, trace: TraceCtx, priority_executors: Sequence[Executor], produce_log=True, log_file_name='autotune_debug.log', visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME) -> None: - from thunder.core.transform_common import dce - # Add more supported ones - self.trace: TraceCtx = dce(trace) - self.always_executors: tuple[Executor, ...] = get_always_executors() - self.computation_graph: Graph = Graph(trace) - self.debug_msg: str = "" - self.empty_executor_hashable_placeholder: str = 'empty' - self.executors: Sequence[Executor] = resolve_executors(['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) - self.fusion_executors: Sequence[FusionExecutor] = [ex for ex in self.executors if isinstance(ex, FusionExecutor)] - self.incremental_search_out_trace: TraceCtx - self.log_file_name: str = log_file_name - self.optimal_trace_mem: TraceCtx = trace - self.optimal_trace_time: TraceCtx = trace +class TraceCandidates: + def __init__(self, best_time: TraceCtx | None = None, best_mem: TraceCtx | None = None) -> None: + self.best_time: TraceCtx | None = best_time + self.best_mem: TraceCtx | None = best_mem + self.time_took: float = 0 + + def __repr__(self) -> str: + return f"\nBest runtime candidate:\n{self.best_time}\nBest memory candidate:\n{self.best_mem}" + + def is_set(self) -> bool: + return False if self.best_time is None or self.best_mem is None else True + + def attach_best_time_candidate(self, trace: TraceCtx): + self.best_time = trace + + def attach_best_mem_candidate(self, trace: TraceCtx): + self.best_mem = trace + + def assign_time_took(self, time: float): + self.time_took = time + + def to_list(self) -> list[TraceCtx | None]: + return [self.best_time, self.best_mem] + + +class FinalOutputCandidates: + def __init__(self, *, fw: TraceCtx, bw: TraceCtx, cost: float) -> None: + self.fw: TraceCtx = fw + self.bw: TraceCtx = bw + self.tot_cost: float = cost + + def __repr__(self) -> str: + return f"Forward trace:\n{self.fw.__repr__()}\nBackward trace:{self.bw.__repr__()}" + + +# Benchmark only traces will contain traces after the rematerialization call with fw and bw calls, reproducing what will be the real traces after the autotune pass +# Non benchmark traces will contain traces after the placement (default) with no call to remat +# We have duplciated those in order to maintain thunder compilation flow as the output from the autotuner will be the traces with no pass through rematerialization +# TODO (matteochen): currently the GREEDY strat is using this data structure, fix this +class FusionStratHelper: + def __init__(self) -> None: + self.supported_executors: set = set(["nvfuser", "torchcompile"]) self.optimized_traces_mem: list[dict[str | Hashable, TraceCtx]] = [] - self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] + self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [ + ] self.optimized_traces_time: list[dict[str | Hashable, TraceCtx]] = [] - self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] - self.partial_costs: dict[TraceCtx, float] = {} + self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [ + ] + + +class ExecutorPlacementOptions: + def __init__(self) -> None: self.placement_options_mem: list[list[Executor]] = [] self.placement_options_time: list[list[Executor]] = [] + + +class BackendOptimizer: + def log(self, what: str): + print( + f"================================================================================ Autotune: {what}") + + def __init__( + self, + *, + priority_executors: Sequence[Executor], + produce_log=True, + apply_bucketing_bw_trace: bool, + log_file_name="autotune_debug.log", + visualizer: Visualizer | None = None, + optimizer_type: OptimizerType = OptimizerType.RUNTIME, + ) -> None: + print(f"OO {optimizer_type}") + self.always_executors: tuple[Executor, ...] = get_always_executors() + self.empty_executor_hashable_placeholder: str = "empty" + self.executors: Sequence[Executor] = resolve_executors( + ["nvfuser", "torchcompile", "sdpa", "cudnn", "torch", "python"] + ) self.priority_executors: Sequence[Executor] = priority_executors - self.produce_log: bool = produce_log - self.strat = None - self.supported_fusion_executors_by_fusion_strat: set = set(['nvfuser', 'torchcompile']) + self.fusion_executors: Sequence[FusionExecutor] = [ + ex for ex in self.executors if isinstance(ex, FusionExecutor) + ] + + self.debug_msg: str = "" + self.partial_costs: dict[TraceCtx, float] = {} self.visualizer: Visualizer | None = visualizer + self.log_file_name: str = log_file_name + self.produce_log: bool = produce_log + self.optimizer_type: OptimizerType = optimizer_type + self.optimization_algorithm: OptimizationAlgorithm | None = None - self.log(f'New trace to optimize (strat = {self.optimizer_type}):\n{self.trace}') - self.log('Executors:') - for o in self.executors: - self.log(f'{o.name} -> is operator = {isinstance(o, OperatorExecutor)}, is fusion = {isinstance(o, FusionExecutor)}') + self.cached_fw_trace: TraceCtx | None = None + self.fw_trace_candidates: TraceCandidates = TraceCandidates() + self.bw_trace_candidates: TraceCandidates = TraceCandidates() + self.out: list[FinalOutputCandidates] = [] - class OptimizationStrat(Enum): - EXAUSTIVE = 1 - GREEDY = 2 - BEST_FUSER = 3 + # Strat greedy + self.computation_graph: Graph + self.incremental_search_out_trace: TraceCtx - # TODO (matteochen): fix this - def __repr__(self) -> str: - return '' + # Strat fusion + self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() + self.executor_placement_options: ExecutorPlacementOptions = ExecutorPlacementOptions() + + self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace - def write(self, file_name): - with open(file_name, 'w') as file: - s = self.__repr__() - file.write(s) - file.close() + self.log("Executors:") + for e in self.executors: + self.log( + f"{e.name} -> is operator = {isinstance(e, OperatorExecutor)}, is fusion = {isinstance(e, FusionExecutor)}" + ) class SearchNode: - def __init__(self, symbol: BoundSymbolInterface, idx: int)-> None: + def __init__(self, symbol: BoundSymbolInterface, idx: int) -> None: self.symbol = symbol self.idx = idx - class TraceType(Enum): - COMPUTATIONAL = 0 - FW = 1 - BW = 2 + # Currently this manages both time and memory + class Result: + def __init__(self) -> None: + self.tm: float = float("inf") + self.mem: float = float("inf") + self.trace: TraceCtx | None = None + self.label: str | Hashable = "" + self.index = -1 + + def attach_cached_fw_traces(self, cached_fw_traces: TraceCandidates) -> None: + self.cached_fw_traces = cached_fw_traces + + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + from thunder.core.transform_common import dce + + self.trace_type = trace_type + # dce for the backward trace will be passed afterwards + self.trace: TraceCtx = dce( + trace) if trace_type == TraceType.FW else trace + + match self.trace_type: + case TraceType.FW: + self.log( + f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") + # TODO (matteochen): support bw trace optimization even though with no fw traces cached + case TraceType.BW: + self.log( + f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") + if not self.fw_trace_candidates.is_set(): + raise AssertionError( + "Can not optimize backward traces before forward traces") # TODO (matteochen): this has a lot in common with the exaustive search, compact them - def build_placement_options_incremental(self, whoami: TraceType = TraceType.COMPUTATIONAL): + def build_placement_options_incremental(self, whoami: TraceType = TraceType.FW): import sys old_max_recursion = sys.getrecursionlimit() sys.setrecursionlimit(2000) # Last index inclusive - def benchmark_partial_trace(trace_in: TraceCtx, last_idx: int, configuration: list[Executor]) -> tuple[float, TraceCtx]: - + def benchmark_partial_trace( + trace_in: TraceCtx, last_idx: int, configuration: list[Executor] + ) -> tuple[float, TraceCtx]: def safe_update_dict(d: dict, key_one, key_two, value): if key_one not in d: - d[key_one] = { - key_two: value - } + d[key_one] = {key_two: value} else: d[key_one][key_two] = value # Retrive all output tensors from each subregion tensors = [] - for i in range(last_idx+1): + for i in range(last_idx + 1): if not isinstance(trace_in.bound_symbols[i], BoundSymbol): - raise AssertionError('Expected BoundSymbol but received BoundSymbolInterface') + raise AssertionError( + "Expected BoundSymbol but received BoundSymbolInterface") s = trace_in.bound_symbols[i] # For each bsym region we expect to output a Tensor tensors.append(s.output) - forced_return_bsym = trace_in.bound_symbols[-1].from_bsym(args=tensors) # Should not be an Interface type at this point + forced_return_bsym = trace_in.bound_symbols[-1].from_bsym( + args=tensors + ) # Should not be an Interface type at this point t = from_trace(trace_in) # Cut the trace to the required depth - t.bound_symbols = list(trace_in.bound_symbols)[:last_idx+1] + t.bound_symbols = list(trace_in.bound_symbols)[: last_idx + 1] t.bound_symbols.append(forced_return_bsym) - configuration.append(Executor(name=self.empty_executor_hashable_placeholder)) # Empty executor for the forced_return + configuration.append( + Executor(name=self.empty_executor_hashable_placeholder) + ) # Empty executor for the forced_return # Place the assigned symbols placed_t = self.place_optimizers(t, configuration) cost, mem, answer = benchmark_trace(placed_t, iters=5) del answer - self.log(f'Executing partial trace for incremental benchmark:\n{placed_t}') - self.log(f'Symbol under test = {t.bound_symbols[-2].sym.name}') - self.log(f'Assigned executor = {configuration[-2].name}') - self.log(f'Time = {cost} ms') + self.log( + f"Executing partial trace for incremental benchmark:\n{placed_t}") + self.log(f"Symbol under test = {t.bound_symbols[-2].sym.name}") + self.log(f"Assigned executor = {configuration[-2].name}") + self.log(f"Time = {cost} ms") # TODO (matteochen): log this to file safe_update_dict(self.partial_costs, whoami, t, cost) return cost, placed_t # We assign an internal id to each symbol based on its idx inside the bound_symbols list def search(node: self.SearchNode, configuration: list[Executor]): - def continue_search(): - if node.idx+1 < max_len: + if node.idx + 1 < max_len: new_idx: int = node.idx + 1 new_symbol: BoundSymbolInterface = bound_symbols[new_idx] search(self.SearchNode(new_symbol, new_idx), configuration) @@ -151,26 +257,29 @@ def continue_search(): all_configurations.append(configuration) has_backend = False - min_cost = float('inf') + min_cost = float("inf") min_cost_ex = None ex: Executor # TODO (matteochen): do parallel for for ex in self.executors: - cost = float('inf') + cost = float("inf") if not isinstance(node.symbol, BoundSymbol): - raise AssertionError("Receive a symbol which is not a BoundSymbol") - if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): + raise AssertionError( + "Receive a symbol which is not a BoundSymbol") + if isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol): has_backend = True configuration.append(ex) - cost, _ = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + cost, _ = benchmark_partial_trace( + self.trace, node.idx, list(configuration)) configuration.pop() - if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): + if isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol): has_backend = True configuration.append(ex) - cost, _ = benchmark_partial_trace(self.trace, node.idx, list(configuration)) + cost, _ = benchmark_partial_trace( + self.trace, node.idx, list(configuration)) configuration.pop() if cost < min_cost: @@ -182,8 +291,10 @@ def continue_search(): continue_search() else: if min_cost_ex is None: - raise AssertionError("Unexpected min cost executor or trace: None") - self.log(f'\nFor id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}\n') + raise AssertionError( + "Unexpected min cost executor or trace: None") + self.log( + f"\nFor id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}\n") configuration.append(min_cost_ex) continue_search() @@ -192,7 +303,8 @@ def continue_search(): all_configurations: list[list[Executor]] = [] # Is the name reserved? - empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + empty_executor = Executor( + name=self.empty_executor_hashable_placeholder) if len(bound_symbols) > 0: search(self.SearchNode(bound_symbols[0], 0), []) @@ -204,35 +316,34 @@ def continue_search(): # Fusion operators as nvFuser can be slower on the single trace region but can be faster by combining more of them, # try to fuse then and compare def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: - best_trace: TraceCtx = trace_in best_time, best_mem, answer = benchmark_trace(best_trace, iters=10) del answer trace_in_time = best_time for ex in self.fusion_executors: - self.log(f'Try to fuse executor {ex.name} with trace:\n{trace_in}') + self.log(f"Try to fuse executor {ex.name} with trace:\n{trace_in}") extrace = ex.fusion_pass(trace_in) - self.log(f'Fused trace:\n{extrace}') - extrace_time, extrace_mem, answer = benchmark_trace(extrace, iters=10) + self.log(f"Fused trace:\n{extrace}") + extrace_time, extrace_mem, answer = benchmark_trace( + extrace, iters=10) del answer - self.log(f'Fused trace time:{extrace_time} ms') + self.log(f"Fused trace time:{extrace_time} ms") if extrace_time < best_time: best_time = extrace_time best_trace = extrace - self.log(f'Trace in (time = {trace_in_time } ms):\n{trace_in}') - self.log(f'Best fused trace (time = {best_time } ms):\n{best_trace}') + self.log(f"Trace in (time = {trace_in_time } ms):\n{trace_in}") + self.log(f"Best fused trace (time = {best_time } ms):\n{best_trace}") return best_trace def build_placement_options_exaustive(self): - # We assign an internal id to each symbol based on its idx inside the bound_symbols list def search(node: self.SearchNode, configuration): def continue_search(): - if node.idx+1 < max_len: + if node.idx + 1 < max_len: new_idx: int = node.idx + 1 new_symbol: BoundSymbolInterface = bound_symbols[new_idx] search(self.SearchNode(new_symbol, new_idx), configuration) @@ -243,13 +354,14 @@ def continue_search(): has_backend = False for ex in self.executors: if not isinstance(node.symbol, BoundSymbol): - raise AssertionError("Receive a symbol which is not a BoundSymbol") - if (isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol)): + raise AssertionError( + "Receive a symbol which is not a BoundSymbol") + if isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol): has_backend = True configuration.append(ex) continue_search() configuration.pop(-1) - if (isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol)): + if isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol): has_backend = True configuration.append(ex) continue_search() @@ -265,14 +377,14 @@ def continue_search(): all_configurations: list[list[Executor]] = [] # Is the name reserved? - empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + empty_executor = Executor( + name=self.empty_executor_hashable_placeholder) if len(bound_symbols) > 0: search(self.SearchNode(bound_symbols[0], 0), []) self.placement_options = all_configurations def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: - from thunder.executors.passes import _transform_for_operator_executor_execution swapmap: dict[Variable, Proxy] = {} @@ -280,7 +392,6 @@ def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: # During the fusion pass and CSE optimizatons some args in trace regions could be different from the cached args. Restore the correct arguments # https://pytorch-lightning.slack.com/archives/C06QA9M8L3C/p1720732254341999 def restore_correct_args(trace_in: TraceCtx): - def args_eq(a, b) -> bool: if len(a) != len(b): return False @@ -290,7 +401,8 @@ def args_eq(a, b) -> bool: return False elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): if obj_a != obj_b: - raise AssertionError(f'What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}') + raise AssertionError( + f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") return True def clear(bsym: BoundSymbol, input): @@ -321,7 +433,7 @@ def update_swapmap(o: Any, no: Any) -> None: def preserve_bsym(bsym: BoundSymbol) -> Any: trace: TraceCtx | None = get_tracectx() if trace is None: - raise AssertionError('None trace context') + raise AssertionError("None trace context") trace.scopes[-1].append(bsym) for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): trace.names.add(p.name) @@ -335,16 +447,16 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: if ex.name == self.empty_executor_hashable_placeholder: return None - execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) + execution_transform: None | Callable = ex.get_execution_transform( + bsym.sym) out: Any - # TODO: What is this? if execution_transform is not None: out = execution_transform(*bsym.args, **bsym.kwargs) elif isinstance(ex, OperatorExecutor): # Calls the operator executor's operation op: Symbol | None = ex.implmap[bsym.sym.id].symbol if op is None: - raise AssertionError('op is None') + raise AssertionError("op is None") out = op(*bsym.args, **bsym.kwargs) elif isinstance(ex, FusionExecutor): # Preserves the symbol as is (it will be handled in the fusion pass) @@ -359,34 +471,52 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE - extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) + if len(executor_list) != len(in_trace.bound_symbols): + raise AssertionError( + "len(executor_list) != len(in_trace.bound_symbols)") + + # self.log(f'Visit transf') + # for n, e in zip(in_trace.bound_symbols, executor_list): + # print(f'{n.sym.name} -> {e.name}') + cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} + executor_mapping: dict[str, Executor] = {} + unique_fusion_executors = set() + + # Input should have equal length + if len(executor_list) != len(in_trace.bound_symbols): + raise AssertionError( + "len(executor_list) != len(extrace.bound_symbols)") + + for b, e in zip(in_trace.bound_symbols, executor_list): + if isinstance(e, FusionExecutor): + unique_fusion_executors.add(e) + if isinstance(b.output, TensorProxy): + executor_mapping[b.output.name] = e + + extrace = transforms.visitor_transform_paired( + in_trace, visit, zip(in_trace.bound_symbols, executor_list)) # Restores original variables bound_symbols: list[BoundSymbol] = [] for bsym in extrace.bound_symbols: nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) bound_symbols.append(nbsym) - extrace.bound_symbols = bound_symbols - unique_fusion_executors = set() - cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} - - if len(executor_list) != len(extrace.bound_symbols): - raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") - - for ex, bsym in zip(executor_list, extrace.bound_symbols): - if isinstance(ex, FusionExecutor): - unique_fusion_executors.add(ex) - elif isinstance(ex, OperatorExecutor): - if isinstance(bsym.output, TensorProxy): - t_proxy_name: str = bsym.output.name - cached_subsymbols[t_proxy_name] = list(bsym.subsymbols) + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy): + t_name = bsym.output.name + if t_name not in executor_mapping: + # Symbol added by the visitor + continue + # raise AssertionError('Failed to retrive key in mapping') + saved_ex = executor_mapping[t_name] + if isinstance(saved_ex, OperatorExecutor): + cached_subsymbols[t_name] = list(bsym.subsymbols) # This will leave out these symbols from the fusion pass bsym.subsymbols = [] # Perform fusion pass - # TODO (matteochen): filter for the current fusion operator as we wanna find the most efficient one for ex in unique_fusion_executors: extrace = ex.fusion_pass(extrace) @@ -406,31 +536,57 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: restore_correct_args(extrace) # Apply always executors - extrace = _transform_for_operator_executor_execution(extrace, self.always_executors) + extrace = _transform_for_operator_executor_execution( + extrace, self.always_executors) return extrace # TODO (matteochen): add config for exaustive search or incremental one - def optimize(self, strat: OptimizationStrat = OptimizationStrat.BEST_FUSER): + def optimize(self, strat: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER): import thunder.core.codeutils as cutils + from thunder.executors.passes import transform_for_execution + from thunder.core.transform_common import replace_redundant_inputs + from thunder.core.transform_common import dce - self.strat = strat + self.optimization_algorithm = strat - from thunder.executors.passes import transform_for_execution def best_fuser(): + # Reset fusion helpers + self.fusion_strat_helper = FusionStratHelper() + # Reset helpers data structures + self.executor_placement_options = ExecutorPlacementOptions() + self.build_placement_options_fusion_regions() - if len(self.placement_options_time) != len(self.fusion_executors): - raise AssertionError("Unexpected time placement options size") - if len(self.placement_options_mem) != len(self.fusion_executors): - raise AssertionError("Unexpected mem placement options size") + if len(self.executor_placement_options.placement_options_time) != len(self.fusion_executors): + raise AssertionError( + f"Unexpected time placement options size: {len(self.executor_placement_options.placement_options_time)}. Expected: {len(self.fusion_executors)}" + ) + if len(self.executor_placement_options.placement_options_mem) != len(self.fusion_executors): + raise AssertionError( + f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors)}" + ) + + for placement, ex in zip(self.executor_placement_options.placement_options_time, self.fusion_executors): + self.fusion_strat_helper.optimized_traces_time.append( + {ex.name: self.place_optimizers(self.trace, placement)} + ) + for placement, ex in zip(self.executor_placement_options.placement_options_mem, self.fusion_executors): + self.fusion_strat_helper.optimized_traces_mem.append( + {ex.name: self.place_optimizers(self.trace, placement)} + ) + + self.benchmark_traces() - for placement, ex in zip(self.placement_options_time, self.fusion_executors): - self.optimized_traces_time.append({ex.name: self.place_optimizers(self.trace, placement)}) - for placement, ex in zip(self.placement_options_mem, self.fusion_executors): - self.optimized_traces_mem.append({ex.name: self.place_optimizers(self.trace, placement)}) + # Cached computed fw traced placements + if self.trace_type == TraceType.FW: + self.cached_fw_traces = self.fw_trace_candidates + self.log(f"Caching fw traces\n{self.cached_fw_traces}") def greedy(): + # Reset helpers data structures + self.executor_placement_options = ExecutorPlacementOptions() + # 1. This builds one option by default self.build_placement_options_incremental() @@ -440,17 +596,19 @@ def greedy(): option = self.placement_options[0] trace_greedy = self.place_optimizers(self.trace, option) # Append the unique trace - self.optimized_traces.append({'greedy': trace_greedy}) + self.optimized_traces.append({"greedy": trace_greedy}) # 2. Try to fuse additional regions from the greedy result # Attention, if all the fused traces perform worse that the greedy one, the greedy one is returned # TODO (matteochen): ignore a duplicated trace - trace_greedy_fused = self.try_to_fuse_after_executors_placement(trace_greedy) - self.optimized_traces.append({'fused_greedy': trace_greedy_fused}) + trace_greedy_fused = self.try_to_fuse_after_executors_placement( + trace_greedy) + self.optimized_traces.append({"fused_greedy": trace_greedy_fused}) # 3. Try the priority list approach - trace_priority = transform_for_execution(self.trace, self.priority_executors) - self.optimized_traces.append({'priority_list': trace_priority}) + trace_priority = transform_for_execution( + self.trace, self.priority_executors) + self.optimized_traces.append({"priority_list": trace_priority}) # There are no hidden placements hence do not call the visualizer @@ -458,45 +616,122 @@ def exaustive(): # This builds one option by default self.build_placement_options_exaustive() - self.log(f'Placement options size: {len(self.placement_options)}') + self.log(f"Placement options size: {len(self.placement_options)}") for option in self.placement_options: option_str = [str(ex.name) for ex in option] - option_str = '-'.join(option_str) + option_str = "-".join(option_str) trace = self.place_optimizers(self.trace, option) if self.visualizer is not None: sig_name = cutils.get_siginfo_name(trace) # TODO (matteochen): consider adding more infos for naming - self.visualizer.set_hidden_trace(f'hidden-{sig_name}-{option_str}', trace) + self.visualizer.set_hidden_trace( + f"hidden-{sig_name}-{option_str}", trace) self.optimized_traces.append({option_str: trace}) - if strat == self.OptimizationStrat.GREEDY: - greedy() - elif strat == self.OptimizationStrat.EXAUSTIVE: - exaustive() - elif strat == self.OptimizationStrat.BEST_FUSER: - best_fuser() + def match_optimizer_algorithm(): + match self.optimization_algorithm: + case OptimizationAlgorithm.GREEDY: + greedy() + case OptimizationAlgorithm.EXAUSTIVE: + exaustive() + case OptimizationAlgorithm.BEST_FUSER: + best_fuser() + + start_time = time.perf_counter_ns() + + match self.trace_type: + case TraceType.FW: + match_optimizer_algorithm() + # We have multiple cached optimized fw traces, find the best backward + case TraceType.BW: + fw_traces = self.fw_trace_candidates.to_list() + # Cached the bw trace as we need to modify the input trace during the loop + cached_self_trace = from_trace(self.trace) + cached_self_trace.bound_symbols = list( + self.trace.bound_symbols) + for t in fw_traces: + # Restore the original bw trace + self.trace = from_trace(cached_self_trace) + self.trace.bound_symbols = list( + cached_self_trace.bound_symbols) + # Set the current active cached forward trace + self.cached_fw_trace = t + + self.log(f"Cached fw trace:\n{self.cached_fw_trace}") + self.log(f"Input bw trace:\n{self.trace}") + + # Some of the optimization passes change proxies in the trace and + # any change in the forward trace must be reflected in the backward + # trace. + original_bw_saved_tensors_for_backward = self.trace.args[0][0] + new_fw_saved_tensors_for_backward = t.output[1][0] + swap_map = { + variableify(x): y + for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) + if variableify(x) != variableify(y) + } + new_bsyms = replace_redundant_inputs( + swap_map, self.trace.bound_symbols) + # replace_redundant_inputs doesn't replace the output of + # UNPACK_SEQUENCE so we do it manually. Here we have certain + # assumptions about the structure of the backward trace. + assert self.trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL + assert self.trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" + assert self.trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE + assert self.trace.bound_symbols[4].args[0].name == "C0" + new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( + swap_map, + skip_inputs=False, + skip_output=False, + skip_subsymbols=False, + ) + self.trace.bound_symbols = new_bsyms + + if self.apply_bucketing_bw_trace: + from thunder.distributed.transforms import FSDPCommBucketing + + self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace( + self.trace) + + # Not called in the constructor for bw traces + dce(self.trace) + + match_optimizer_algorithm() + + end_time = time.perf_counter_ns() + if self.trace_type == TraceType.FW: + self.fw_trace_candidates.assign_time_took( + (end_time - start_time) // 1000000) else: - raise AssertionError('Optimization strat not implemented') + self.bw_trace_candidates.assign_time_took( + (end_time - start_time) // 1000000) - def build_placement_options_fusion_regions(self, increment_factor:int = 1): + def build_placement_options_fusion_regions(self, increment_factor: int = 1): from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols + from thunder.core.rematerialization import rematerialize_forward_and_backward def sequence_hash(s: Sequence) -> str: name = "" for e in s: - if isinstance(e, CollectionProxy) or isinstance(e, TensorProxy): + if ( + isinstance(e, CollectionProxy) + or isinstance(e, TensorProxy) + or isinstance(e, IntegerProxy) + or isinstance(e, FloatProxy) + ): name += e.name elif e is None: name += "None" else: - raise AssertionError(f'What? type = {type(e)}') + raise AssertionError( + f"What? Maybe nested Sequence. type = {type(e)}") return name # TODO (matteochen): Benchmark the optimal executor and call this optimal - def get_default_executor(bsym: BoundSymbol): + def get_optimal_executor(bsym: BoundSymbol): for ex in self.executors: if isinstance(ex, FusionExecutor): continue @@ -505,133 +740,74 @@ def get_default_executor(bsym: BoundSymbol): return Executor(name=self.empty_executor_hashable_placeholder) def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[BoundSymbol]): - self.log(f'Input mapping len = {len(mapping)}:') - self.log(f'Input bound_symbols len = {len(bound_symbols_in)}:') + # self.log(f'Input mapping len = {len(mapping)}:') + # self.log(f'Input bound_symbols len = {len(bound_symbols_in)}:') trc = from_trace(self.trace) trc.bound_symbols = list(bound_symbols_in) - # for b in trc.bound_symbols: - # print(b.sym.name) - - # print(f'trc:\n{trc}') - - # def find_original_return_tensors(trace_in: TraceCtx) -> list[Any]: - # return_bsym = trace_in.bound_symbols[-1] - # if return_bsym.sym.name != 'return': - # raise AssertionError(f'Expected return symbol got {return_bsym.sym.name}') - - # ans = [] - # if isinstance(return_bsym.args, tuple): - # # forward trace - # if isinstance(return_bsym.args[0], dict): - # ans.append(return_bsym.args[0]['output']) - # # backward trace - # else: - # ans.extend([s for s in return_bsym.args if s is not None]) - # else: - # raise AssertionError('Not supported') - - # return ans - - # def find_last_out_tensor(trace_in: TraceCtx): - # m = 0 - # t = None - # for b in trace_in.bound_symbols: - # if b.sym.name == 'return': - # continue - # if isinstance(b.output, TensorProxy): - # if is_possible_out(b.output.name) and int(b.output.name[1:]) > m: - # m = int(b.output.name[1:]) - # t = b.output - # # else: - # # raise AssertionError(f'Not implemented, type = {type(b.output)}') - # if t is None: - # raise AssertionError('Max tensor output not found') - # print(f'max tensor out name: {t}') - # return t - - # def is_tensor_in_bsyms(t: TensorProxy | tuple): - # def handle_tuple(tup: tuple): - # for e in tup: - # if isinstance(e, TensorProxy): - # for b in trc.bound_symbols: - # if b is not None: - # if isinstance(b.output, TensorProxy): - # if b.output.name == e.name: - # return b.output - # else: - # raise AssertionError('Not supported') - - # if isinstance(t, TensorProxy): - # for b in trc.bound_symbols: - # if b is not None: - # if isinstance(b.output, TensorProxy): - # if b.output.name == t.name: - # return b.output - # return None - # else: - # handle_tuple(t) - - - # tensors = [] - # for b in bound_symbols_in: - # if isinstance(b.output, TensorProxy): - # tensors.append(b.output) - # We include always the last tensor as output of the partial trace + all the already - # available out tensor present in the original trace in order to not be discarded from the dce - # tensors = [find_last_out_tensor(trc)] - # original_returns = find_original_return_tensors(self.trace) - # for t in original_returns: - # # TODO (matteochen): improve this - # res = is_tensor_in_bsyms(t) - # if res is not None: - # tensors.append(res) - # For this partial trace we have to return all not used tensors otherwise the dce will cut them out - tensors = return_not_used(trc) + tensors = return_not_used_vars(trc) - forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=tensors) + forced_return_bsym = self.trace.bound_symbols[-1].from_bsym( + args=tensors) executor_configuration = [] - empty_executor = Executor(name=self.empty_executor_hashable_placeholder) + empty_executor = Executor( + name=self.empty_executor_hashable_placeholder) keys = [] - for bsym in trc.bound_symbols: - if bsym.sym.name == 'return': - raise AssertionError('return statement should not be here') - executor_configuration.append(empty_executor) - keys.append('return') + for bsym in trc.bound_symbols: + if bsym.sym.name == "return": + raise AssertionError("Return statement should not be here") + # executor_configuration.append(empty_executor) + # keys.append('return') elif isinstance(bsym.output, Sequence): seq_hash = sequence_hash(bsym.output) - executor_configuration.append(mapping.get(seq_hash, empty_executor)) + executor_configuration.append( + mapping.get(seq_hash, empty_executor)) keys.append(seq_hash) - elif isinstance(bsym.output, CollectionProxy) or isinstance(bsym.output, TensorProxy): + elif ( + isinstance(bsym.output, CollectionProxy) + or isinstance(bsym.output, TensorProxy) + or isinstance(bsym.output, IntegerProxy) + or isinstance(bsym.output, FloatProxy) + ): if bsym.output.name not in mapping: - raise AssertionError(f'Expected key {bsym.output.name} in mapping {mapping}') + raise AssertionError( + f"Expected key {bsym.output.name} in mapping {mapping}") executor_configuration.append(mapping[bsym.output.name]) keys.append(bsym.output.name) else: - raise AssertionError(f"Type not handled: {type(bsym.output)}") + raise AssertionError( + f"Type not handled: {type(bsym.output)}") - if trc.bound_symbols[-1].sym.name != 'return': + if trc.bound_symbols[-1].sym.name != "return": trc.bound_symbols.append(forced_return_bsym) - executor_configuration.append(Executor(name=self.empty_executor_hashable_placeholder)) - keys.append('return') + executor_configuration.append( + Executor(name=self.empty_executor_hashable_placeholder)) + keys.append("return") if len(trc.bound_symbols) != len(executor_configuration) or len(keys) != len(executor_configuration): - raise AssertionError(f'len trc.bound_symbols ({len(trc.bound_symbols)}) != len executor_configuration ({len(executor_configuration)}) != len keys ({len(keys)})') + raise AssertionError( + f"len trc.bound_symbols ({len(trc.bound_symbols)}) != len executor_configuration ({len(executor_configuration)}) != len keys ({len(keys)})" + ) + + # for b, e in zip(trc.bound_symbols, executor_configuration): + # if isinstance(b.output, TensorProxy): + # print(f'{b.sym.name}: {b.output.name} -> {e.name}') - # self.log(f'Before placement trc:\n{trc}') placed_trace = self.place_optimizers(trc, executor_configuration) return placed_trace, keys, executor_configuration ex: FusionExecutor for ex in self.fusion_executors: + if ex.name not in self.fusion_strat_helper.supported_executors: + raise AssertionError( + f"Fusion operator not supported: {ex.name}") - if ex.name not in self.supported_fusion_executors_by_fusion_strat: - raise AssertionError(f'Fusion operator not supported: {ex.name}') + self.log( + f"Searching best placement for fusion executor = {ex.name}") - self.log(f'Searching best placement for fusion executor = {ex.name}') - # TODO (matteochen): each executor has a custo def + # TODO (matteochen): each executor has a custom should fuse function, can we make this prettier? def _should_fuse_nvfuser(a: Node, b: Node): def _can_fuse_node(n: Node): # if already merged, then node can be fused @@ -641,6 +817,7 @@ def _can_fuse_node(n: Node): can_fuse: bool = ex.can_fuse(bsym) cuda_in_or_out: bool = ex.has_cuda_input_or_output(bsym) return can_fuse and cuda_in_or_out + return _can_fuse_node(a) and _can_fuse_node(b) def _should_fuse_torchcompile(a: Node, b: Node): @@ -649,44 +826,62 @@ def _can_fuse_node(n: Node): return True bsym: BoundSymbol = n.group_bsyms[0] return ex.can_fuse(bsym) + return _can_fuse_node(a) and _can_fuse_node(b) - bound_symbol_groups =fuse_bound_symbols(self.trace, _should_fuse_nvfuser if ex.name == 'nvfuser' else _should_fuse_torchcompile) - self.log(f'Num of groups = {len(bound_symbol_groups)}') + def match_bsym_output(bsym_in: BoundSymbol, time_dict: dict, mem_dict: dict, ex_in: Executor): + if isinstance(bsym_in.output, Sequence): + time_dict[sequence_hash(bsym_in.output)] = ex_in + mem_dict[sequence_hash(bsym_in.output)] = ex_in + elif ( + isinstance(bsym_in.output, CollectionProxy) + or isinstance(bsym_in.output, TensorProxy) + or isinstance(bsym_in.output, IntegerProxy) + or isinstance(bsym_in.output, FloatProxy) + ): + time_dict[bsym_in.output.name] = ex_in + mem_dict[bsym_in.output.name] = ex_in + else: + raise AssertionError( + f"Type not handled: {type(bsym_in.output)}") + + bound_symbol_groups = fuse_bound_symbols( + self.trace, _should_fuse_nvfuser if ex.name == "nvfuser" else _should_fuse_torchcompile + ) + self.log(f"Num of groups = {len(bound_symbol_groups)}") - for group in bound_symbol_groups: + for id, group in enumerate(bound_symbol_groups): + self.log(f"Group id: {id}") for sub in group: - self.log(f'{sub.sym.name} -> out: {sub.output}') + self.log(f"{sub.sym.name} -> out: {sub.output}") if len(group) > 0: - print('\n') + print("\n") - map_time: dict[str, Executor] = {} - map_mem: dict[str, Executor] = {} + dict_time_strat: dict[str, Executor] = {} + dict_mem_strat: dict[str, Executor] = {} increasing_symbols = [] for group_id, group in enumerate(bound_symbol_groups): - self.log(f'group start = {group[0].sym.name}') - self.log(f'group end = {group[-1].sym.name}') + self.log(f"Group id: {group_id}") + self.log(f"group start = {group[0].sym.name}") + self.log(f"group end = {group[-1].sym.name}") - if group[0].sym.name != 'return': + if group[0].sym.name != "return": increasing_symbols += group - # Is not a fusion region, get the default executor + # Is not a fusion region, get the optimal executor if len(group) < 2: - symbol = group[0] - self.log(f'--> Single group: {symbol.sym.name}') - name = symbol.sym.name - ex_for_this = get_default_executor(symbol) - if name == 'return': - map_time['return'] = ex_for_this - map_mem['return'] = ex_for_this + current_bsym = group[0] + self.log(f"--> Single group: {current_bsym.sym.name}") + name = current_bsym.sym.name + optimal_ex = get_optimal_executor(current_bsym) + if name == "return": + dict_time_strat["return"] = optimal_ex + dict_mem_strat["return"] = optimal_ex # Add the modified return statement at the end of the for loop break - elif isinstance(symbol.output, Sequence): - map_time[sequence_hash(symbol.output)] = ex_for_this - map_mem[sequence_hash(symbol.output)] = ex_for_this - elif isinstance(symbol.output, CollectionProxy) or isinstance(symbol.output, TensorProxy): - map_time[symbol.output.name] = ex_for_this - map_mem[symbol.output.name] = ex_for_this + # before was ex??? + match_bsym_output( + current_bsym, dict_time_strat, dict_mem_strat, optimal_ex) continue # Inside groups we should have alwasy tensors as out @@ -694,245 +889,405 @@ def _can_fuse_node(n: Node): best_res_mem = self.Result() worst_res_time = self.Result() worst_res_mem = self.Result() + # Only for visual worst_res_mem.measure = 0 worst_res_time.measure = 0 + # TODO (matteochen): Aggregate them best_placement_time = None best_keys_time = None best_placement_mem = None best_keys_mem = None - # Each iteration of this loop will have map_time = map_mem, hence we use and fill only map_time - # Best time and best mem will be recorded separatedly though - for i in range(len(group)): - # From top to bottom (this will include the whole region) - # -> First iteration is the one with fusion region with single element - # -> Last iteration gives the complete fusion region - for j in range(0, i+1, increment_factor): - map_time[group[j].output.name] = ex - map_mem[group[j].output.name] = ex - for k in range(i+1, len(group), increment_factor): - map_time[group[k].output.name] = get_default_executor(group[k]) - map_mem[group[k].output.name] = get_default_executor(group[k]) - # Benchmark this placement - trc, keys, placements = get_placed_trace(map_time, increasing_symbols) - cost, mem, out = benchmark_trace(trc, iters=1) + def measure_and_update_result(): + nonlocal best_res_time + nonlocal best_placement_time + nonlocal best_keys_time + nonlocal worst_res_time + nonlocal best_res_mem + nonlocal best_placement_mem + nonlocal best_keys_mem + nonlocal worst_res_mem + trc, keys, placements = get_placed_trace( + dict_time_strat, increasing_symbols) + if self.trace_type == TraceType.BW and self.cached_fw_trace is not None: + _, trc = rematerialize_forward_and_backward( + self.cached_fw_trace, trc) + cost, mem, out = benchmark_trace(trc, iters=3) del out - self.log(f'Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}') - if cost < best_res_time.measure: - best_res_time.measure = cost + self.log( + f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}") + if cost < best_res_time.tm or (cost == best_res_time.tm and mem < best_res_time.mem): + best_res_time.tm = cost + best_res_time.mem = mem best_res_time.trace = trc best_placement_time = placements best_keys_time = keys - if cost > worst_res_time.measure: - worst_res_time.measure = cost + if cost > worst_res_time.tm: + worst_res_time.tm = cost - if mem < best_res_mem.measure: - best_res_mem.measure = mem + if mem < best_res_mem.mem or (mem == best_res_mem.mem and cost < best_res_mem.tm): + best_res_mem.tm = cost + best_res_mem.mem = mem best_res_mem.trace = trc best_placement_mem = placements best_keys_mem = keys - if mem > worst_res_mem.measure: - worst_res_mem.measure = mem + if mem > worst_res_mem.mem: + worst_res_mem.mem = mem + + start_idx = 0 + # This is to accomodate the following TODO + # TODO: investigate why is failing with torchcompile if left alone + if ex.name == "torchcompile": + last_embedding_idx = -1 + for idx in range(0, len(group)): + if group[idx].sym.name == "embedding_backward": + last_embedding_idx = idx + self.log(f"last embedding {last_embedding_idx}") + if last_embedding_idx != -1: + # Until last_embedding_idx (included) assigned to current fusion ex + for i in range(0, last_embedding_idx + 1, 1): + match_bsym_output( + group[i], dict_time_strat, dict_mem_strat, ex) + + if last_embedding_idx == len(group) - 1: + # Benchmark + measure_and_update_result() + + start_idx = last_embedding_idx + 1 + + n_missing_bsyms = len(group) - start_idx + for i in range(0, n_missing_bsyms): + # From top to bottom (this will include the whole region) + # -> First iteration is the one with fusion region with single element + # -> Last iteration gives the complete fusion region + for j in range(start_idx, start_idx + i + 1, increment_factor): + match_bsym_output( + group[j], dict_time_strat, dict_mem_strat, ex) + for k in range(start_idx + i + 1, len(group), increment_factor): + match_bsym_output( + group[k], dict_time_strat, dict_mem_strat, get_optimal_executor( + group[k + start_idx]) + ) + # Benchmark + measure_and_update_result() + + # TODO (matteochen): consider if this can increase placement # From bottom to up (this will exclude the full region as being handled in the for cycle above) # -> First iteration is the one with len(fusion_region) - 1 # -> Last iteration gives no fusion regions # for j in range(0, i+1, increment_factor): - # map_time[group[j].output.name] = get_default_executor(group[j]) + # dict_time_strat[group[j].output.name] = get_default_executor(group[j]) # for k in range(i+1, len(group), increment_factor): - # map_time[group[k].output.name] = ex + # dict_time_strat[group[k].output.name] = ex # Benchmark this placement - # trc, keys, placements = get_placed_trace(map_time, increasing_symbols) - # cost, out = benchmark_trace(trc, iters=2) - # del out - # self.log(f'Placed trace (cost = {cost } ms)\n{trc}') - # if cost < best_time: - # best_time = cost - # best_trc = trc - # best_placement = placements - # best_keys = keys + # measure_and_update_result() + if best_placement_time is None or best_keys_time is None: - raise AssertionError('Failed to get best placement') + raise AssertionError("Failed to get best time placement") if best_placement_mem is None or best_keys_mem is None: - raise AssertionError('Failed to get best placement') + raise AssertionError("Failed to get best placement") - self.log(f'For group {group_id} best placement with time cost = {best_res_time.measure} ms (worst time = {worst_res_time.measure} ms):\n{best_res_time.trace}') - self.log(f'For group {group_id} best placement with mem cost = {best_res_mem.measure / (2**30)} GB (worst mem = {worst_res_mem.measure/(2**30)} GB) is:\n{best_res_mem.trace}') + self.log( + f"For group {group_id} best placement with time cost = {best_res_time.tm} ms (worst time = {worst_res_time.tm} ms):\n{best_res_time.trace}" + ) + self.log( + f"For group {group_id} best placement with mem cost = {best_res_mem.mem / (2**30)} GB (worst mem = {worst_res_mem.mem/(2**30)} GB) is:\n{best_res_mem.trace}" + ) # for n, p in zip(best_keys, best_placement): # print(f'{n} -> {p.name}') # Update our dict for n, p in zip(best_keys_time, best_placement_time): - map_time |= {n: p} + dict_time_strat |= {n: p} # Update our dict for n, p in zip(best_keys_mem, best_placement_mem): - map_mem |= {n: p} - - # self.log('End of group search') - # pprint.pprint(map_time) - - # print('map cmp') - # for k in map_time.keys(): - # if k not in map_mem: - # pprint.pprint(map_time) - # pprint.pprint(map_mem) - # raise AssertionError(f"cannot find {k}") - # pprint.pprint(map_time) - # pprint.pprint(map_mem) + dict_mem_strat |= {n: p} # Generate the placement executors_time = [] executors_mem = [] - for bsym in self.trace.bound_symbols: - if bsym.sym.name == 'return': - if 'return' not in map_time or 'return' not in map_mem: - raise AssertionError(f'Expected key return in mapping {map_time} and {map_mem}') - executors_time.append(map_time['return']) - executors_mem.append(map_mem['return']) + for bsym in self.trace.bound_symbols: + if bsym.sym.name == "return": + if "return" not in dict_time_strat or "return" not in dict_mem_strat: + raise AssertionError( + f"Expected key return in mapping {dict_time_strat} and {dict_mem_strat}") + executors_time.append(dict_time_strat["return"]) + executors_mem.append(dict_mem_strat["return"]) elif isinstance(bsym.output, Sequence): seq_hash = sequence_hash(bsym.output) - if seq_hash not in map_time or seq_hash not in map_mem: - raise AssertionError(f'Expected key {seq_hash} in mapping {map_time} and {map_mem}') - executors_time.append(map_time[seq_hash]) - executors_mem.append(map_mem[seq_hash]) - elif isinstance(bsym.output, CollectionProxy) or isinstance(bsym.output, TensorProxy): - if bsym.output.name not in map_time or bsym.output.name not in map_mem: - raise AssertionError(f'Expected key {bsym.output.name} in mapping {map_time} and {map_mem}') - executors_time.append(map_time[bsym.output.name]) - executors_mem.append(map_mem[bsym.output.name]) + if seq_hash not in dict_time_strat or seq_hash not in dict_mem_strat: + raise AssertionError( + f"Expected key {seq_hash} in mapping {dict_time_strat} and {dict_mem_strat}" + ) + executors_time.append(dict_time_strat[seq_hash]) + executors_mem.append(dict_mem_strat[seq_hash]) + elif ( + isinstance(bsym.output, CollectionProxy) + or isinstance(bsym.output, TensorProxy) + or isinstance(bsym.output, IntegerProxy) + or isinstance(bsym.output, FloatProxy) + ): + if bsym.output.name not in dict_time_strat or bsym.output.name not in dict_mem_strat: + raise AssertionError( + f"Expected key {bsym.output.name} in mapping {dict_time_strat} and {dict_mem_strat}" + ) + executors_time.append(dict_time_strat[bsym.output.name]) + executors_mem.append(dict_mem_strat[bsym.output.name]) else: - raise AssertionError(f"Type not handled: {type(bsym.output)}") - - # Swap return bsym otherwise with no call to remat, we will trace the wrong memory occupation - test_trc = from_trace(self.trace) - test_trc.bound_symbols = list(self.trace.bound_symbols) - test_trc.bound_symbols.pop() - test_trc.bound_symbols.append(self.trace.bound_symbols[-1].from_bsym(args=return_not_used(test_trc))) - trc = self.place_optimizers(test_trc, executors_mem) - c, m, o = benchmark_trace(trc) - del o - self.log(f'Debug MEM, mem = {m/(2**30)} GB:\n{trc}') - self.optimized_traces_mem_benchmark_only.append({ex.name: trc}) - trc = self.place_optimizers(test_trc, executors_time) - c, m, o = benchmark_trace(trc) - del o - self.log(f'Debug TIME, time = {c} ms:\n{trc}') - self.optimized_traces_time_benchmark_only.append({ex.name: trc}) + raise AssertionError( + f"Type not handled: {type(bsym.output)}") + + # For the forward trace we benchmark (memory) the mocked return statement as we don't know which + # Tensor will be returned after the rematerialize_forward_and_backward() call in order to do not overestimate the memory consumption + if self.trace_type == TraceType.FW: + trc = from_trace(self.trace) + trc.bound_symbols = list(self.trace.bound_symbols) + trc.bound_symbols.pop() + trc.bound_symbols.append( + self.trace.bound_symbols[-1].from_bsym(args=return_not_used_vars(trc))) + # NOTE: Here the active trace to place will be 'trc' and not 'self.trace' + trc_time = self.place_optimizers(trc, executors_mem) + c, m, o = benchmark_trace(trc_time) + del o + self.log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc_time}") + self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ + ex.name: trc_time}) + trc_mem = self.place_optimizers(trc, executors_time) + c, m, o = benchmark_trace(trc_mem) + del o + self.log(f"Debug TIME, time = {c} ms:\n{trc_mem}") + self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ + ex.name: trc_mem}) + else: + trc = self.place_optimizers(self.trace, executors_mem) + _, trc = rematerialize_forward_and_backward( + self.cached_fw_trace, trc) + c, m, o = benchmark_trace(trc) + del o + self.log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc}") + self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ + ex.name: trc}) + trc = self.place_optimizers(self.trace, executors_time) + _, trc = rematerialize_forward_and_backward( + self.cached_fw_trace, trc) + c, m, o = benchmark_trace(trc) + del o + self.log(f"Debug TIME, time = {c} ms:\n{trc}") + self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ + ex.name: trc}) # Save executors in order to generate real fw and bw trace with correct output - self.placement_options_time.append(executors_time) - self.placement_options_mem.append(executors_mem) - - def get_optimal_trace(self) -> TraceCtx: - if self.optimizer_type == OptimizerType.RUNTIME: - return self.optimal_trace_time - else: - return self.optimal_trace_mem + self.executor_placement_options.placement_options_time.append( + executors_time) + self.executor_placement_options.placement_options_mem.append( + executors_mem) + + def get_optimal_fw_traces_time_and_mem(self) -> tuple[TraceCtx, TraceCtx]: + if self.fw_trace_candidates.best_time is None or self.fw_trace_candidates.best_mem is None: + raise AssertionError("Failed to obtain optimal fw traces") + return self.fw_trace_candidates.best_time, self.fw_trace_candidates.best_mem + + def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + # This is agnostic from the optimization strat as results are both floats + min_value: float = float("inf") + ans: FinalOutputCandidates | None = None + for pair in self.out: + if pair.tot_cost < min_value: + self.log(f"New best pair:\n{pair}") + min_value = pair.tot_cost + ans = pair + return ans.fw, ans.bw def bsym_assigned(self, bsym: BoundSymbol) -> bool: return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) - class Result: - def __init__(self) -> None: - self.measure: float = float('inf') - self.trace: TraceCtx | None = None - self.label: str | Hashable = "" - self.index = -1 - def benchmark_traces(self): - tm = self.Result() mem = self.Result() - self.debug_msg += 'Traces benchmarks:\n\n' + self.debug_msg += "Traces benchmarks:\n\n" source_mem = None source_time = None - if self.strat == self.OptimizationStrat.BEST_FUSER: - source_mem = self.optimized_traces_mem_benchmark_only - source_time = self.optimized_traces_time_benchmark_only - elif self.strat == self.OptimizationStrat.GREEDY: - source_mem = self.optimized_traces_mem - source_time = self.optimized_traces_time - else: - raise AssertionError('Not supported') - + match self.optimization_algorithm: + case OptimizationAlgorithm.BEST_FUSER: + # TODO: handle requests for no remat when introduced in thunder (split_fw_bw) + source_mem = self.fusion_strat_helper.optimized_traces_mem_benchmark_only + source_time = self.fusion_strat_helper.optimized_traces_time_benchmark_only + case OptimizationAlgorithm.GREEDY: + # TODO (matteochen) + raise AssertionError("Not supported") + case OptimizationAlgorithm.EXAUSTIVE: + raise AssertionError("Not supported") + + # Find best trace for runtime for i, trace_info in enumerate(source_time): - + # Unpack the dict label = None trace = None for k, v in trace_info.items(): label = k trace = v - trace_time, _, res = benchmark_trace(trace, iters=10) + trace_time, trace_mem, res = benchmark_trace(trace, iters=10) del res - self.debug_msg += f'Trace name = [{label}] - Time = {trace_time} ms\n{trace}\n\n' - self.log(f'Benchmark trace "{label}" (time = {trace_time} ms:\n{trace}') - if trace_time < tm.measure: - tm.measure = trace_time + self.debug_msg += ( + f"Trace name = [{label}] - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" + ) + self.log( + f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' + ) + if trace_time < tm.tm: + tm.tm = trace_time + tm.mem = trace_mem tm.trace = trace tm.label = label tm.index = i + # Find best trace for memory for i, trace_info in enumerate(source_mem): - + # Unpack the dict label = None trace = None for k, v in trace_info.items(): label = k trace = v - _, trace_mem, res = benchmark_trace(trace, iters=10) + trace_time, trace_mem, res = benchmark_trace(trace, iters=10) del res - self.debug_msg += f'Trace name = [{label}] - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n' - self.log(f'Benchmark trace "{label}" (mem = {trace_mem / (2 ** 30)} GB):\n{trace}') - if trace_mem < mem.measure: - mem.measure = trace_mem + self.debug_msg += ( + f"Trace name = [{label}] - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" + ) + self.log( + f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' + ) + if trace_mem < mem.mem: + mem.tm = trace_time + mem.mem = trace_mem mem.trace = trace mem.label = label mem.index = i - self.log(f'Benchmark end: Best trace time "{tm.label} (time = {tm.measure} ms)":\n{tm.trace}') - self.log(f'Benchmark end: Best trace mem "{mem.label} (mem = {mem.measure / (2 ** 30)} GB)":\n{mem.trace}') + self.log( + f'Benchmark end: Best trace time "{tm.label} (time = {tm.tm} ms)":\n{tm.trace}') + self.log( + f'Benchmark end: Best trace mem "{mem.label} (mem = {mem.mem / (2 ** 30)} GB)":\n{mem.trace}') - self.log('Strat comparison') + # TODO (matteochen): remove this + self.log("Strat comparison") c, m, o = benchmark_trace(tm.trace) del o - self.log(f'best time: {c} ms, {m/(2**30)} GB') + self.log(f"best time: {c} ms, {m/(2**30)} GB") c, m, o = benchmark_trace(mem.trace) del o - self.log(f'best mem: {c} ms, {m/(2**30)} GB') - - # TODO (matteochen): use time or mem strat - if self.strat == self.OptimizationStrat.GREEDY: - self.optimal_trace_time = tm.trace - self.optimal_trace_mem = mem.trace - elif self.strat == self.OptimizationStrat.BEST_FUSER: - d = self.optimized_traces_time[tm.index] - t = None - for _, v in d.items(): - t = v - self.optimal_trace_time = t - d = self.optimized_traces_mem[mem.index] - t = None - for _, v in d.items(): - t = v - self.optimal_trace_mem = t - - self.log(f'Saved best trace time:\n{self.optimal_trace_time}') - self.log(f'Saved best trace mem:\n{self.optimal_trace_mem}') + self.log(f"best mem: {c} ms, {m/(2**30)} GB") + + match self.optimization_algorithm: + case OptimizationAlgorithm.GREEDY: + raise AssertionError("Not implemented") + case OptimizationAlgorithm.BEST_FUSER: + # Here we have to recover the traces without the pass through remat in order to be compliant + # with thunder flow as we might have request for no remat + + d = self.fusion_strat_helper.optimized_traces_time[tm.index] + t = None + # Unpack dict + for _, v in d.items(): + t = v + if t is None: + raise AssertionError("None trace") + + match self.trace_type: + case TraceType.FW: + self.fw_trace_candidates.attach_best_time_candidate(t) + case TraceType.BW: + self.bw_trace_candidates.attach_best_time_candidate(t) + + d = self.fusion_strat_helper.optimized_traces_mem[mem.index] + t = None + # Unpack dict + for _, v in d.items(): + t = v + if t is None: + raise AssertionError("None trace") + + match self.trace_type: + case TraceType.FW: + self.fw_trace_candidates.attach_best_mem_candidate(t) + case TraceType.BW: + self.bw_trace_candidates.attach_best_mem_candidate(t) + + match self.trace_type: + case TraceType.FW: + self.log(self.fw_trace_candidates.__repr__()) + case TraceType.BW: + self.log(self.bw_trace_candidates.__repr__()) + + # Now, finally build the pair fw and bw traces for the requested strat + if self.trace_type == TraceType.BW: + forward_time, forward_memory, _ = benchmark_trace( + self.cached_fw_trace, iters=10) + match self.optimizer_type: + case OptimizerType.RUNTIME: + # Used the computed benchmark from above + if tm.tm < mem.tm: + self.log( + f"out candidate times: (fw){forward_time} ms, (bw){tm.tm} ms") + self.out.append( + FinalOutputCandidates( + fw=self.cached_fw_trace, + bw=self.bw_trace_candidates.best_time, + cost=forward_time + tm.tm, + ) + ) + else: + self.log( + f"out candidate times: (fw){forward_time} ms, (bw){mem.tm} ms") + self.out.append( + FinalOutputCandidates( + fw=self.cached_fw_trace, + bw=self.bw_trace_candidates.best_mem, + cost=forward_time + mem.tm, + ) + ) + case OptimizerType.MEMORY: + # Used the computed benchmark from above + if tm.mem < mem.mem: + self.log( + f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){tm.mem} GB") + self.out.append( + FinalOutputCandidates( + fw=self.cached_fw_trace, + bw=self.bw_trace_candidates.best_time, + cost=forward_memory + tm.mem, + ) + ) + else: + self.log( + f"out candidate mem: (fw){forward_memory} GB, (bw){mem.mem} GB") + self.out.append( + FinalOutputCandidates( + fw=self.cached_fw_trace, + bw=self.bw_trace_candidates.best_mem, + cost=forward_memory + mem.mem, + ) + ) if self.produce_log: - with open(self.log_file_name, 'w') as file: + import time + + timestamp: str = str(time.time()) + with open(f"{timestamp}-{self.log_file_name}", "w") as file: file.write(self.debug_msg) file.close() -def return_not_used(trace_in: TraceCtx) -> list[TensorProxy]: - def is_in_sequence(seq: Sequence[Any], t:TensorProxy): + +def return_not_used_vars(trace_in: TraceCtx) -> list[TensorProxy]: + def is_in_sequence(seq: Sequence[Any], t: TensorProxy): for e in seq: if isinstance(e, TensorProxy) and e.name == t.name: return True @@ -940,7 +1295,7 @@ def is_in_sequence(seq: Sequence[Any], t:TensorProxy): # Check if this naming is always valid def is_possible_out(name: str): - if not name.startswith('t'): + if not name.startswith("t"): return False num = name[1:] return num.isdigit() @@ -955,63 +1310,85 @@ def is_possible_out(name: str): if not is_possible_out(b.output.name): continue for test in trace_in.bound_symbols: - if test.args is not None and (isinstance(test.args, tuple) or isinstance(test.args, list)) and is_in_sequence(test.args, b.output): + if ( + test.args is not None + and (isinstance(test.args, tuple) or isinstance(test.args, list)) + and is_in_sequence(test.args, b.output) + ): f = True break if not f: ans.append(b.output) return ans + # This will benchmark the input trace with the del_last_used call -def benchmark_trace(trace: TraceCtx, iters: int = 1, show_func = False, apply_del_last_used = True, snapshot = False, snapshot_name = "") -> tuple[float, float, Any]: +# TODO (matteochen): move into utils module +def benchmark_trace( + trace: TraceCtx, iters: int = 1, show_func=False, apply_del_last_used=True, snapshot=False, snapshot_name="" +) -> tuple[float, float, Any]: from thunder.executors.passes import del_last_used import inspect input_args = [] if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: - raise AssertionError('Missing return statement') + raise AssertionError("Missing return statement") def compute_time_cost_ms(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: - warm_up_iters = 3 - out = None - torch.cuda.empty_cache() - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - - max_allocated_bytes = 0 - # Warm up cycles - for _ in range(warm_up_iters): - fn(*args) - # Snapshot request - if snapshot: - torch.cuda.memory._record_memory_history() - fn(*args) - torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") - torch.cuda.memory._record_memory_history(enabled=None) - # Benchmark - stream = torch.cuda.current_stream() - for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + try: + warm_up_iters = 3 + out = None torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - fn(*args) - end_events[i].record(stream) - max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) - - torch.cuda.synchronize() - times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(times) / iters - return tot_time, max_allocated_bytes, out - - def print_input_args(args, level=0, show_content = False): + + start_events = [torch.cuda.Event( + enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) + for _ in range(iters)] + + max_allocated_bytes = 0 + # Warm up cycles + for _ in range(warm_up_iters): + fn(*args) + # Snapshot request + if snapshot: + torch.cuda.memory._record_memory_history() + fn(*args) + torch.cuda.memory._dump_snapshot( + snapshot_name + "_benchmark.pickle") + torch.cuda.memory._record_memory_history(enabled=None) + # Benchmark + stream = torch.cuda.current_stream() + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + fn(*args) + end_events[i].record(stream) + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + times = [s.elapsed_time(e) + for s, e in zip(start_events, end_events)] + tot_time = sum(times) / iters + return tot_time, max_allocated_bytes, out + except Exception as e: + import inspect + + trc = inspect.getsource(fn) + print(f"#FN EXECUTION FAILED:\n{trc}") + raise e + + def print_input_args(args, level=0, show_content=False): for e in args: if isinstance(e, tuple) or isinstance(e, list): - print_input_args(e, level=level+1) + print_input_args(e, level=level + 1) else: - print(f'level {level}', type(e)) + print(f"level {level}", type(e)) # def print_trace_execution_output(out: Any, show_content=False): # if isinstance(out, tuple): @@ -1020,44 +1397,88 @@ def print_input_args(args, level=0, show_content = False): # else: # print(f'{type(out)}') - def thunder_to_torch_float_dtype(byte: int) -> torch.dtype: - if (byte == 2): - return torch.float16 - elif (byte == 4): + def thunder_to_torch_float_dtype(tp: dtype, byte: int) -> torch.dtype: + if byte == 1: + raise AssertionError("Not implmented: 8 bit float") + # Dispatch flaot 16 type 1 from type 2 + elif byte == 2: + if tp._name == thunder.bfloat16._name: + return torch.bfloat16 + else: + return torch.float16 + elif byte == 4: return torch.float32 - else: + elif byte == 8: return torch.float64 + else: + raise AssertionError(f"Not supported byte = {byte}") + + def thunder_to_torch_int_dtype(byte: int) -> torch.dtype: + if byte == 1: + return torch.int8 + elif byte == 2: + return torch.int16 + elif byte == 4: + return torch.int32 + elif byte == 8: + return torch.int64 + else: + raise AssertionError(f"Not supported byte = {byte}") + # TODO (matteochen): use more appropriate mock int and float def transform_input_tuple(t: tuple, level=0) -> tuple: res = [] for e in t: if type(e) is tuple: - res.append(transform_input_tuple(e, level+1)) + res.append(transform_input_tuple(e, level + 1)) else: if isinstance(e, TensorProxy): res.append(transform_tensor(e)) + elif isinstance(e, IntegerProxy): + if e.python_type is bool: + res.append(False if e.value is None else e.value) + else: + res.append(0 if e.value is None else e.value) + elif isinstance(e, FloatProxy): + res.append(0.0 if e.value is None else e.value) else: # TODO (matteochen): support more data types - raise AssertionError(f'Input arg type not recognized: {type(e)}') + raise AssertionError( + f"Input arg type not recognized: {type(e)}") return tuple(res) def transform_tensor(arg: TensorProxy) -> torch.Tensor: + from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype + + # TODO (matteochen): Missing parallel and fsdp handling... + # TODO (matteochen): Missing support for meta types ... dtype = arg.dtype - if dtype is not None and type(dtype) is thunder.dtypes.floating: - torch_dtype = thunder_to_torch_float_dtype(dtype.bytes) + shape = arg.shape + device = arg.device + requires_grad = arg.requires_grad + if dtype is not None and is_float_dtype(dtype): + torch_dtype = thunder_to_torch_float_dtype(dtype, dtype.bytes) + tensor: torch.Tensor = torch.randn( + *shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif dtype is not None and is_signedinteger_dtype(dtype): + torch_dtype = thunder_to_torch_int_dtype(dtype.bytes) + tensor: torch.Tensor = torch.randint( + 0, 10, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif dtype is not None and is_boolean_dtype(dtype): + # TODO (matteochen): maybe random? + tensor: torch.Tensor = torch.zeros( + *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad + ) else: # TODO (matteochen): support other types raise AssertionError(f"dtype {dtype} not supported yet") - shape = arg.shape - device = arg.device - requires_grad = arg.requires_grad - # TODO (matteochen): Missing parallel and fsdp handling... - # TODO (matteochen): Missing support for meta types ... - tensor: torch.Tensor = torch.randn(*shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad) return tensor # Can we remove this check? + # TODO (matteochen): use more appropriate mock int and float if isinstance(trace.args, Sequence): for arg in trace.args: if isinstance(arg, tuple): @@ -1065,23 +1486,70 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: elif isinstance(arg, TensorProxy): e = transform_tensor(arg) input_args.append(e) + elif isinstance(arg, IntegerProxy): + if arg.python_type is bool: + input_args.append( + False if arg.value is None else arg.value) + else: + input_args.append(0 if arg.value is None else arg.value) + elif isinstance(arg, FloatProxy): + input_args.append(0.0 if arg.value is None else arg.value) else: - raise AssertionError(f'Input arg type not recognized: {type(arg)}') + raise AssertionError( + f"Input arg type not recognized: {type(arg)}") else: - raise AssertionError('Unexpexcted args type') + raise AssertionError("Unexpexcted args type") # Always benchmark trace after a deletion last used pass as the final trace out will passed under this stage if apply_del_last_used: trace = del_last_used(trace) + # print(f'BENCHMARKING:\n{trace}') + # def p(args): + # for e in args: + # if not isinstance(e, Sequence): + # if isinstance(e, torch.Tensor): + # print(f'{e.size()}') + # else: + # try: + # print(f'{e.name} -> {e}') + # except: + # print(f'{e}') + # else: + # print('rec') + # p(e) + # p(trace.args) + # print('##################') + # p(input_args) + trace_tok = set_tracectx(trace) # Obtain the python executable string executable = trace.python_callable() if show_func: print(inspect.getsource(executable)) - t, m, answer = compute_time_cost_ms(executable, iters, *input_args) - reset_tracectx(trace_tok) + t = float("inf") + m = float("inf") + answer = None + try: + t, m, answer = compute_time_cost_ms(executable, iters, *input_args) + except Exception as e: + # https://github.com/Lightning-AI/lightning-thunder/issues/664 + print(f"Exception:\n{e}") + if "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e): + print( + "Executing with torch compile no full graph (this might still fail), see: https://github.com/Lightning-AI/lightning-thunder/issues/664" + ) + torch_compiled = torch.compile(executable, fullgraph=False) + try: + t, m, answer = compute_time_cost_ms( + torch_compiled, iters, *input_args) + except Exception as e: + print(f"Compiled trace execution still failed:\n{e}") + else: + print(f"Unknown exception occured:\n{e}") + finally: + reset_tracectx(trace_tok) return t, m, answer diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 65d454c2e6..3427d9724f 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -23,7 +23,7 @@ from thunder.executors.pythonex import clear_mutable_collection from thunder.extend import Executor, get_all_executors, get_always_executors, OperatorExecutor, FusionExecutor -from thunder.backend_optimizer.optimizer import BackendOptimizer, OptimizerType +from thunder.backend_optimizer.optimizer import BackendOptimizer, OptimizerType, TraceCandidates, TraceType from thunder.visualizer.graphviz import create_graphviz_pdf from thunder.visualizer.visualizer_helper import Visualizer @@ -136,34 +136,55 @@ def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: return extrace # Autotuned transform_for_execution version -def autotune_transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor], autotune_type: OptimizerType, visualizer: Visualizer | None = None) -> TraceCtx: +def autotune_transform_for_execution( + *, optimizer_context: BackendOptimizer, trace: TraceCtx, trace_type: TraceType +) -> tuple[TraceCtx, TraceCtx] | None: import torch + start_time_ns = time.perf_counter_ns() + # Recover the function name sig_name = cutils.get_siginfo_name(trace) - start_time_ns = time.perf_counter_ns() - if torch.distributed.is_available(): # Apply AllReduce bucketing if possible & needed from thunder.distributed.transforms.ddp import apply_bucketing_to_grad_allreduce trace = apply_bucketing_to_grad_allreduce(trace) - trace = dce(trace) - - backend_optimizer = BackendOptimizer(trace, executors_list, produce_log=True, log_file_name=f'autotune_transform_for_execution_{sig_name}.log', visualizer=visualizer, optimizer_type=autotune_type) - backend_optimizer.optimize() - backend_optimizer.benchmark_traces() - extrace = backend_optimizer.get_optimal_trace() + # Attach new trace and set the debug file name + optimizer_context.attach_trace(trace=trace, trace_type=trace_type) + optimizer_context.log_file_name = f'autotune_transform_for_execution_{sig_name}.log' + # Forward traces are cached inside the context + optimizer_context.optimize() + match trace_type: + case TraceType.FW: + # Nothing more left + pass + # When optimizing the backward pass, the optimizer will return the best fw and bw traces based on the requested autotune_type, no need to choose the fw pass manually + case TraceType.BW: + fw_extrace, bw_extrace = optimizer_context.get_optimal_fw_bw_traces() end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - extrace.set_provenance(TraceProvenance(f"Autotuned transform for execution (strat: {autotune_type}) (took {elapsed_time_millis} milliseconds)")) - return extrace - + # Assign the trace provenance + match trace_type: + case TraceType.FW: + fw_extrace_time, fw_extrace_mem = optimizer_context.get_optimal_fw_traces_time_and_mem() + fw_extrace_time.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) + fw_extrace_mem.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) + return None + case TraceType.BW: + bw_extrace.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) + return fw_extrace, bw_extrace def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: import torch diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index d51ffc16a3..931cb173e9 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -3,6 +3,7 @@ import torch +from thunder.backend_optimizer.optimizer import OptimizerType import thunder.core.utils as utils from thunder.core.prims import PrimIDs from thunder.core.proxies import TensorProxy, variableify @@ -105,7 +106,6 @@ def backward(ctx, *args): del grads return (None, None, None, None, None, *([None] * n_grads)) -# TODO (matteochen): add control for using autotuner or not def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, autotune_type, /, *flat_args): from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace @@ -113,8 +113,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops from thunder.executors.passes import del_last_used, transform_for_execution, autotune_transform_for_execution from thunder.visualizer.visualizer_helper import Visualizer - - visualizer = Visualizer(produce_hidden=False) + from thunder.backend_optimizer.optimizer import TraceType, BackendOptimizer utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used @@ -166,64 +165,77 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat _fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc) fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace) + do_apply_bucketing_bw_trace: bool = getattr(compile_data.fn, "use_fsdp", False) + # Now we can run the optimization passes on the forward trace - # TODO Restore request for no rematerialization + visualizer = Visualizer(produce_hidden=False) + backend_optimizer_ctx: BackendOptimizer | None = ( + None + if autotune_type is None + else BackendOptimizer( + priority_executors=compile_data.executors_list, + apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, + produce_log=True, + visualizer=visualizer, + optimizer_type=autotune_type, + ) + ) visualizer.set_fw_initial_trace(fw_trace) - if autotune_type is not None: - fw_extrace = autotune_transform_for_execution( - fw_trace, - executors_list=compile_data.executors_list, - autotune_type=autotune_type, - visualizer=visualizer + # Get optimzied fw trace + fw_extrace = ( + transform_for_execution(fw_trace, executors_list=compile_data.executors_list) + if autotune_type is None + else autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=fw_trace, trace_type=TraceType.FW ) - else: - fw_extrace = transform_for_execution( - fw_trace, - executors_list=compile_data.executors_list - ) - fw_traces.append(fw_extrace) - visualizer.set_fw_optimized_trace(fw_extrace) - - # Some of the optimization passes change proxies in the trace and - # any change in the forward trace must be reflected in the backward - # trace. - original_bw_saved_tensors_for_backward = bw_trace.args[0][0] - new_fw_saved_tensors_for_backward = fw_extrace.output[1][0] - swap_map = { - variableify(x): y - for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) - if variableify(x) != variableify(y) - } - new_bsyms = replace_redundant_inputs(swap_map, bw_trace.bound_symbols) - # replace_redundant_inputs doesn't replace the output of - # UNPACK_SEQUENCE so we do it manually. Here we have certain - # assumptions about the structure of the backward trace. - assert bw_trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL - assert bw_trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" - assert bw_trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE - assert bw_trace.bound_symbols[4].args[0].name == "C0" - new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( - swap_map, - skip_inputs=False, - skip_output=False, - skip_subsymbols=False, ) - bw_trace.bound_symbols = new_bsyms - if getattr(compile_data.fn, "use_fsdp", False): - bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) + # If in default mode, otherwise the best fw will be returned only at the end + if autotune_type is None: + fw_traces.append(fw_extrace) + visualizer.set_fw_optimized_trace(fw_extrace) + + # NOTE: autotuner will take care of this + # Some of the optimization passes change proxies in the trace and + # any change in the forward trace must be reflected in the backward + # trace. + original_bw_saved_tensors_for_backward = bw_trace.args[0][0] + new_fw_saved_tensors_for_backward = fw_extrace.output[1][0] + swap_map = { + variableify(x): y + for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) + if variableify(x) != variableify(y) + } + new_bsyms = replace_redundant_inputs(swap_map, bw_trace.bound_symbols) + # replace_redundant_inputs doesn't replace the output of + # UNPACK_SEQUENCE so we do it manually. Here we have certain + # assumptions about the structure of the backward trace. + assert bw_trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL + assert bw_trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" + assert bw_trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE + assert bw_trace.bound_symbols[4].args[0].name == "C0" + new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( + swap_map, + skip_inputs=False, + skip_output=False, + skip_subsymbols=False, + ) + bw_trace.bound_symbols = new_bsyms + + if do_apply_bucketing_bw_trace: + bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization + visualizer.set_bw_initial_trace(bw_trace) if autotune_type is not None: - bw_extrace = autotune_transform_for_execution( - bw_trace, - executors_list=compile_data.executors_list, - autotune_type=autotune_type, - visualizer=visualizer + fw_extrace, bw_extrace = autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=bw_trace, trace_type=TraceType.BW ) + fw_traces.append(fw_extrace) + visualizer.set_bw_optimized_trace(fw_extrace) else: bw_extrace = transform_for_execution( bw_trace, @@ -232,6 +244,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat bw_traces.append(bw_extrace) visualizer.set_bw_optimized_trace(bw_extrace) + # TODO Restore request for no rematerialization fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) From 0f7ced3415849fd284db0fe1625b75e7ecc97523 Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Fri, 26 Jul 2024 16:08:00 +0300 Subject: [PATCH 019/171] Updated test --- examples/dev/litGPT.py | 76 ++++++++++++++++++++++-------------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index ec8c78044d..587e27cf28 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -4,46 +4,48 @@ import torch from thunder.backend_optimizer.optimizer import benchmark_trace -cfg = Config.from_name('Llama-2-7b-hf') -cfg.n_layer = 1 # fewer layers -torch.set_default_dtype(torch.bfloat16) +layers = [4, 8, 16, 32] -with torch.device('cuda'): - model = GPT(cfg) - x = torch.randint(1, model.config.vocab_size, (1, 512)) - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime') - y = jmodel_def(x) - yy = jmodel_auto(x) +for l in layers: + print('Layers:', l) + cfg = Config.from_name('Llama-2-7b-hf') + cfg.n_layer = l + torch.set_default_dtype(torch.bfloat16) + with torch.device('cuda'): + model = GPT(cfg) + x = torch.randint(1, model.config.vocab_size, (1, 512)) + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='runtime') + y = jmodel_def(x) + yy = jmodel_auto(x) - jmodel_def = thunder.jit(model) - # This model fails under some circumstances after passed the placed traced under the rematelizer - jmodel_auto = thunder.jit(model, autotune_type='memory') + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='runtime') - y = model(x) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) - print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + y = model(x) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_fw') - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_fw') - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_bw') - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_bw') - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - del o + print('Results ########################################') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_def_fw') + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_auto_fw') + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_def_bw') + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + del o + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_auto_bw') + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + del o - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') From 579888d309e31a230a2a0b6cfb5b01f14ca8131e Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Sun, 28 Jul 2024 17:57:56 +0300 Subject: [PATCH 020/171] Fixed bad list index / removed print --- thunder/backend_optimizer/optimizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index ff6d6d210a..0b4a6ac8da 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -114,7 +114,6 @@ def __init__( visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME, ) -> None: - print(f"OO {optimizer_type}") self.always_executors: tuple[Executor, ...] = get_always_executors() self.empty_executor_hashable_placeholder: str = "empty" self.executors: Sequence[Executor] = resolve_executors( @@ -968,7 +967,7 @@ def measure_and_update_result(): for k in range(start_idx + i + 1, len(group), increment_factor): match_bsym_output( group[k], dict_time_strat, dict_mem_strat, get_optimal_executor( - group[k + start_idx]) + group[k]) ) # Benchmark measure_and_update_result() From 3528e6c82d52620b81dcec9d0ddd8871e2b732eb Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Sun, 28 Jul 2024 21:18:37 +0300 Subject: [PATCH 021/171] Disabled graphviz / modified test runner --- examples/dev/litGPT.py | 21 ++++++++++----------- thunder/backend_optimizer/optimizer.py | 6 +++--- thunder/executors/torch_autograd.py | 2 +- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 587e27cf28..e3f76cb293 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -4,25 +4,24 @@ import torch from thunder.backend_optimizer.optimizer import benchmark_trace -layers = [4, 8, 16, 32] +class Test: + def __init__(self, layers: int, autotune_type: str) -> None: + self.layers = layers + self.autotune_type = autotune_type -for l in layers: - print('Layers:', l) +layers = [Test(2, 'runtime')] + +for test in layers: + print('Layers:', test.layers) cfg = Config.from_name('Llama-2-7b-hf') - cfg.n_layer = l + cfg.n_layer = test.layers torch.set_default_dtype(torch.bfloat16) with torch.device('cuda'): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (1, 512)) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime') - y = jmodel_def(x) - yy = jmodel_auto(x) - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime') + jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type) - y = model(x) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 0b4a6ac8da..77bc2c85b6 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -14,6 +14,7 @@ from typing import Any, Hashable import thunder import thunder.core.transforms as transforms +import concurrent.futures import torch import time @@ -1321,7 +1322,6 @@ def is_possible_out(name: str): return ans -# This will benchmark the input trace with the del_last_used call # TODO (matteochen): move into utils module def benchmark_trace( trace: TraceCtx, iters: int = 1, show_func=False, apply_del_last_used=True, snapshot=False, snapshot_name="" @@ -1396,6 +1396,7 @@ def print_input_args(args, level=0, show_content=False): # else: # print(f'{type(out)}') + # TODO (matteochen): convert this into dict def thunder_to_torch_float_dtype(tp: dtype, byte: int) -> torch.dtype: if byte == 1: raise AssertionError("Not implmented: 8 bit float") @@ -1412,6 +1413,7 @@ def thunder_to_torch_float_dtype(tp: dtype, byte: int) -> torch.dtype: else: raise AssertionError(f"Not supported byte = {byte}") + # TODO (matteochen): convert this into dict def thunder_to_torch_int_dtype(byte: int) -> torch.dtype: if byte == 1: return torch.int8 @@ -1441,7 +1443,6 @@ def transform_input_tuple(t: tuple, level=0) -> tuple: elif isinstance(e, FloatProxy): res.append(0.0 if e.value is None else e.value) else: - # TODO (matteochen): support more data types raise AssertionError( f"Input arg type not recognized: {type(e)}") return tuple(res) @@ -1471,7 +1472,6 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad ) else: - # TODO (matteochen): support other types raise AssertionError(f"dtype {dtype} not supported yet") return tensor diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 931cb173e9..99ed7492c5 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -324,6 +324,6 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat visualizer.set_fw_final_trace(fw_extrace) visualizer.set_bw_final_trace(bw_extrace) - visualizer.produce() + # visualizer.produce() return fw_extrace, bw_extrace From f67bb1ccf24afd49109c10abe7600dcb76adb95e Mon Sep 17 00:00:00 2001 From: Kaixi Matteo Chen Date: Sun, 28 Jul 2024 21:27:55 +0300 Subject: [PATCH 022/171] Using user defined executor list or default as unique executors ref in autotuner --- examples/dev/litGPT.py | 6 ++++-- thunder/backend_optimizer/optimizer.py | 7 ++----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index e3f76cb293..2a8537be2a 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -19,8 +19,10 @@ def __init__(self, layers: int, autotune_type: str) -> None: with torch.device('cuda'): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (1, 512)) - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type) + executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python'] + + jmodel_def = thunder.jit(model, executors=executors) + jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors=executors) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 77bc2c85b6..11553177f4 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -117,10 +117,7 @@ def __init__( ) -> None: self.always_executors: tuple[Executor, ...] = get_always_executors() self.empty_executor_hashable_placeholder: str = "empty" - self.executors: Sequence[Executor] = resolve_executors( - ["nvfuser", "torchcompile", "sdpa", "cudnn", "torch", "python"] - ) - self.priority_executors: Sequence[Executor] = priority_executors + self.executors: Sequence[Executor] = priority_executors self.fusion_executors: Sequence[FusionExecutor] = [ ex for ex in self.executors if isinstance(ex, FusionExecutor) ] @@ -607,7 +604,7 @@ def greedy(): # 3. Try the priority list approach trace_priority = transform_for_execution( - self.trace, self.priority_executors) + self.trace, self.executors) self.optimized_traces.append({"priority_list": trace_priority}) # There are no hidden placements hence do not call the visualizer From 5f76bcf17bdb7933d9818b5df438b01e30fdba9d Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 29 Jul 2024 14:24:07 +0300 Subject: [PATCH 023/171] Computing bw traces taking in consideration every fw traces options (from different FusionExecutors) (#4) --- examples/dev/LLaMAMLP.py | 14 +- examples/dev/MLP.py | 37 +- examples/dev/litGPT.py | 25 +- thunder/backend_optimizer/optimizer.py | 536 +++++++++++-------------- thunder/executors/passes.py | 12 +- 5 files changed, 253 insertions(+), 371 deletions(-) diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 98fadb8dbd..28d26e867a 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -20,25 +20,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: a = 4096 * mult b = 11008 * mult x = torch.randn(2, 2048, a, requires_grad=True) + model = LLaMAMLP(a, b) - jmodel_def = thunder.jit(model, executors=['torchcompile', 'nvfuser']) - jmodel_auto = thunder.jit(model, autotune_type='runtime') + + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='memory', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) y = model(x) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_fw') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_fw', iters=10) print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') del o - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_fw') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_fw', iters=10) print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_bw') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_bw', iters=10) print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], iters=2, apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_bw') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_bw', iters=10) print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') del o diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py index cd803d76c9..4c1fa488ae 100644 --- a/examples/dev/MLP.py +++ b/examples/dev/MLP.py @@ -2,10 +2,6 @@ import torch.nn as nn import thunder from thunder.backend_optimizer.optimizer import benchmark_trace -# import logging - -# torch._logging.set_logs(dynamo = logging.DEBUG) -# torch._dynamo.config.verbose = True class ModelConfig: def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): @@ -40,25 +36,21 @@ def forward(self, x): jmodel_def = thunder.jit(model) # This model fails under some circumstances after passed the placed traced under the rematelizer - jmodel_auto = thunder.jit(model, autotune_type='memory') + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) y = model(x) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_fw') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_fw', iters=10) print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_fw') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_fw', iters=10) print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_bw') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_bw', iters=10) print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_bw') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_bw', iters=10) print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - del o print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') @@ -70,22 +62,3 @@ def forward(self, x): print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - # from torch.profiler import profile, record_function, ProfilerActivity - # with profile(activities=[ - # ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - # with record_function("def"): - # y = jmodel_def(x) - # grad_outputs = torch.ones_like(y) - # torch.autograd.grad(y, x, grad_outputs=grad_outputs) - - # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - # with profile(activities=[ - # ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - # with record_function("auto"): - # y = jmodel_auto(x) - # grad_outputs = torch.ones_like(y) - # torch.autograd.grad(y, x, grad_outputs=grad_outputs) - - # print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 2a8537be2a..effb664c25 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -9,37 +9,34 @@ def __init__(self, layers: int, autotune_type: str) -> None: self.layers = layers self.autotune_type = autotune_type -layers = [Test(2, 'runtime')] +layers = [Test(1, 'runtime')] + +model_name = 'Llama-2-7b-hf' for test in layers: print('Layers:', test.layers) - cfg = Config.from_name('Llama-2-7b-hf') + cfg = Config.from_name(model_name) cfg.n_layer = test.layers torch.set_default_dtype(torch.bfloat16) with torch.device('cuda'): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (1, 512)) - executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python'] - jmodel_def = thunder.jit(model, executors=executors) - jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors=executors) + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('Results ########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_def_fw') + c, m, _ = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_auto_fw') + c, m, _ = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_def_bw') + c, m, _ = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='llama-2-7b-hf_auto_bw') + c, m, _ = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - del o print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 11553177f4..76afecfbcf 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -9,14 +9,13 @@ from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx from thunder.executors.data_dependent_partition import Graph, Node -from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors, resolve_executors +from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Any, Hashable import thunder import thunder.core.transforms as transforms import concurrent.futures import torch -import time class OptimizerType(Enum): @@ -30,9 +29,8 @@ class TraceType(Enum): class OptimizationAlgorithm(Enum): - EXAUSTIVE = 0 - GREEDY = 1 - BEST_FUSER = 2 + GREEDY = 0 + BEST_FUSER = 1 class OptimizerNode: @@ -48,7 +46,6 @@ class TraceCandidates: def __init__(self, best_time: TraceCtx | None = None, best_mem: TraceCtx | None = None) -> None: self.best_time: TraceCtx | None = best_time self.best_mem: TraceCtx | None = best_mem - self.time_took: float = 0 def __repr__(self) -> str: return f"\nBest runtime candidate:\n{self.best_time}\nBest memory candidate:\n{self.best_mem}" @@ -62,11 +59,8 @@ def attach_best_time_candidate(self, trace: TraceCtx): def attach_best_mem_candidate(self, trace: TraceCtx): self.best_mem = trace - def assign_time_took(self, time: float): - self.time_took = time - - def to_list(self) -> list[TraceCtx | None]: - return [self.best_time, self.best_mem] + def iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: + return self.best_time, self.best_mem class FinalOutputCandidates: @@ -76,7 +70,7 @@ def __init__(self, *, fw: TraceCtx, bw: TraceCtx, cost: float) -> None: self.tot_cost: float = cost def __repr__(self) -> str: - return f"Forward trace:\n{self.fw.__repr__()}\nBackward trace:{self.bw.__repr__()}" + return f"Final output candidate: forward trace:\n{self.fw.__repr__()}\nFinal output candidate: backward trace:{self.bw.__repr__()}" # Benchmark only traces will contain traces after the rematerialization call with fw and bw calls, reproducing what will be the real traces after the autotune pass @@ -131,14 +125,13 @@ def __init__( self.optimizer_type: OptimizerType = optimizer_type self.optimization_algorithm: OptimizationAlgorithm | None = None - self.cached_fw_trace: TraceCtx | None = None - self.fw_trace_candidates: TraceCandidates = TraceCandidates() + self.active_fw_trace: TraceCtx | None = None + self.cached_fw_traces: dict[str | Hashable, TraceCandidates] = {} self.bw_trace_candidates: TraceCandidates = TraceCandidates() self.out: list[FinalOutputCandidates] = [] # Strat greedy self.computation_graph: Graph - self.incremental_search_out_trace: TraceCtx # Strat fusion self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() @@ -146,6 +139,8 @@ def __init__( self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace + self.benchmark_iters = 10 + self.log("Executors:") for e in self.executors: self.log( @@ -158,7 +153,7 @@ def __init__(self, symbol: BoundSymbolInterface, idx: int) -> None: self.idx = idx # Currently this manages both time and memory - class Result: + class BenchmarkResult: def __init__(self) -> None: self.tm: float = float("inf") self.mem: float = float("inf") @@ -166,10 +161,10 @@ def __init__(self) -> None: self.label: str | Hashable = "" self.index = -1 - def attach_cached_fw_traces(self, cached_fw_traces: TraceCandidates) -> None: - self.cached_fw_traces = cached_fw_traces + def attach_cached_fw_traces(self, cached_fw_traces: TraceCandidates, executor_name: str) -> None: + self.cached_fw_traces[executor_name] = cached_fw_traces - def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): from thunder.core.transform_common import dce self.trace_type = trace_type @@ -183,14 +178,13 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") # TODO (matteochen): support bw trace optimization even though with no fw traces cached case TraceType.BW: - self.log( - f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") - if not self.fw_trace_candidates.is_set(): + if not self.cached_fw_traces: raise AssertionError( "Can not optimize backward traces before forward traces") + self.log( + f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") - # TODO (matteochen): this has a lot in common with the exaustive search, compact them - def build_placement_options_incremental(self, whoami: TraceType = TraceType.FW): + def build_placement_options_greedy(self): import sys old_max_recursion = sys.getrecursionlimit() @@ -232,7 +226,7 @@ def safe_update_dict(d: dict, key_one, key_two, value): # Place the assigned symbols placed_t = self.place_optimizers(t, configuration) - cost, mem, answer = benchmark_trace(placed_t, iters=5) + cost, mem, answer = benchmark_trace(placed_t, self.benchmark_iters) del answer self.log( f"Executing partial trace for incremental benchmark:\n{placed_t}") @@ -314,7 +308,7 @@ def continue_search(): # try to fuse then and compare def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: best_trace: TraceCtx = trace_in - best_time, best_mem, answer = benchmark_trace(best_trace, iters=10) + best_time, best_mem, answer = benchmark_trace(best_trace, self.benchmark_iters) del answer trace_in_time = best_time @@ -323,7 +317,7 @@ def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: extrace = ex.fusion_pass(trace_in) self.log(f"Fused trace:\n{extrace}") extrace_time, extrace_mem, answer = benchmark_trace( - extrace, iters=10) + extrace, self.benchmark_iters) del answer self.log(f"Fused trace time:{extrace_time} ms") @@ -336,51 +330,6 @@ def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: return best_trace - def build_placement_options_exaustive(self): - # We assign an internal id to each symbol based on its idx inside the bound_symbols list - def search(node: self.SearchNode, configuration): - def continue_search(): - if node.idx + 1 < max_len: - new_idx: int = node.idx + 1 - new_symbol: BoundSymbolInterface = bound_symbols[new_idx] - search(self.SearchNode(new_symbol, new_idx), configuration) - else: - all_configurations.append(list(configuration)) - - ex: Executor - has_backend = False - for ex in self.executors: - if not isinstance(node.symbol, BoundSymbol): - raise AssertionError( - "Receive a symbol which is not a BoundSymbol") - if isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol): - has_backend = True - configuration.append(ex) - continue_search() - configuration.pop(-1) - if isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol): - has_backend = True - configuration.append(ex) - continue_search() - configuration.pop(-1) - - if not has_backend: - configuration.append(empty_executor) - continue_search() - configuration.pop(-1) - - bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols - max_len = len(bound_symbols) - - all_configurations: list[list[Executor]] = [] - # Is the name reserved? - empty_executor = Executor( - name=self.empty_executor_hashable_placeholder) - - if len(bound_symbols) > 0: - search(self.SearchNode(bound_symbols[0], 0), []) - self.placement_options = all_configurations - def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: from thunder.executors.passes import _transform_for_operator_executor_execution @@ -540,20 +489,19 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: # TODO (matteochen): add config for exaustive search or incremental one def optimize(self, strat: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER): - import thunder.core.codeutils as cutils from thunder.executors.passes import transform_for_execution from thunder.core.transform_common import replace_redundant_inputs from thunder.core.transform_common import dce self.optimization_algorithm = strat - def best_fuser(): + def optmize_best_fuser(): # Reset fusion helpers self.fusion_strat_helper = FusionStratHelper() # Reset helpers data structures self.executor_placement_options = ExecutorPlacementOptions() - self.build_placement_options_fusion_regions() + self.build_placement_options_best_fuser() if len(self.executor_placement_options.placement_options_time) != len(self.fusion_executors): raise AssertionError( @@ -575,17 +523,12 @@ def best_fuser(): self.benchmark_traces() - # Cached computed fw traced placements - if self.trace_type == TraceType.FW: - self.cached_fw_traces = self.fw_trace_candidates - self.log(f"Caching fw traces\n{self.cached_fw_traces}") - - def greedy(): + def optimize_greedy(): # Reset helpers data structures self.executor_placement_options = ExecutorPlacementOptions() # 1. This builds one option by default - self.build_placement_options_incremental() + self.build_placement_options_greedy() if len(self.placement_options) != 1: raise AssertionError("Unexpected placement options size") @@ -609,104 +552,82 @@ def greedy(): # There are no hidden placements hence do not call the visualizer - def exaustive(): - # This builds one option by default - self.build_placement_options_exaustive() - - self.log(f"Placement options size: {len(self.placement_options)}") - - for option in self.placement_options: - option_str = [str(ex.name) for ex in option] - option_str = "-".join(option_str) - trace = self.place_optimizers(self.trace, option) - - if self.visualizer is not None: - sig_name = cutils.get_siginfo_name(trace) - # TODO (matteochen): consider adding more infos for naming - self.visualizer.set_hidden_trace( - f"hidden-{sig_name}-{option_str}", trace) - - self.optimized_traces.append({option_str: trace}) + # Run benchmarks + self.benchmark_traces() def match_optimizer_algorithm(): match self.optimization_algorithm: case OptimizationAlgorithm.GREEDY: - greedy() - case OptimizationAlgorithm.EXAUSTIVE: - exaustive() + optimize_greedy() case OptimizationAlgorithm.BEST_FUSER: - best_fuser() - - start_time = time.perf_counter_ns() + optmize_best_fuser() match self.trace_type: case TraceType.FW: match_optimizer_algorithm() # We have multiple cached optimized fw traces, find the best backward case TraceType.BW: - fw_traces = self.fw_trace_candidates.to_list() # Cached the bw trace as we need to modify the input trace during the loop cached_self_trace = from_trace(self.trace) cached_self_trace.bound_symbols = list( self.trace.bound_symbols) - for t in fw_traces: - # Restore the original bw trace - self.trace = from_trace(cached_self_trace) - self.trace.bound_symbols = list( - cached_self_trace.bound_symbols) - # Set the current active cached forward trace - self.cached_fw_trace = t - - self.log(f"Cached fw trace:\n{self.cached_fw_trace}") - self.log(f"Input bw trace:\n{self.trace}") - - # Some of the optimization passes change proxies in the trace and - # any change in the forward trace must be reflected in the backward - # trace. - original_bw_saved_tensors_for_backward = self.trace.args[0][0] - new_fw_saved_tensors_for_backward = t.output[1][0] - swap_map = { - variableify(x): y - for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) - if variableify(x) != variableify(y) - } - new_bsyms = replace_redundant_inputs( - swap_map, self.trace.bound_symbols) - # replace_redundant_inputs doesn't replace the output of - # UNPACK_SEQUENCE so we do it manually. Here we have certain - # assumptions about the structure of the backward trace. - assert self.trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL - assert self.trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" - assert self.trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE - assert self.trace.bound_symbols[4].args[0].name == "C0" - new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( - swap_map, - skip_inputs=False, - skip_output=False, - skip_subsymbols=False, - ) - self.trace.bound_symbols = new_bsyms - - if self.apply_bucketing_bw_trace: - from thunder.distributed.transforms import FSDPCommBucketing - - self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace( - self.trace) - - # Not called in the constructor for bw traces - dce(self.trace) - - match_optimizer_algorithm() - - end_time = time.perf_counter_ns() - if self.trace_type == TraceType.FW: - self.fw_trace_candidates.assign_time_took( - (end_time - start_time) // 1000000) - else: - self.bw_trace_candidates.assign_time_took( - (end_time - start_time) // 1000000) + for label, candidate in self.cached_fw_traces.items(): + self.log(f'Backward optimization with fw from {label}') + fw_traces = candidate.iterable() + for trc in fw_traces: + + # TODO (matteochen): unify below with the original block + + # Restore the original bw trace + self.trace = from_trace(cached_self_trace) + self.trace.bound_symbols = list( + cached_self_trace.bound_symbols) + # Set the current active cached forward trace + self.active_fw_trace = trc + + self.log(f"Cached fw trace:\n{self.active_fw_trace}") + self.log(f"Input bw trace:\n{self.trace}") + + # Some of the optimization passes change proxies in the trace and + # any change in the forward trace must be reflected in the backward + # trace. + original_bw_saved_tensors_for_backward = self.trace.args[0][0] + new_fw_saved_tensors_for_backward = trc.output[1][0] + swap_map = { + variableify(x): y + for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) + if variableify(x) != variableify(y) + } + new_bsyms = replace_redundant_inputs( + swap_map, self.trace.bound_symbols) + # replace_redundant_inputs doesn't replace the output of + # UNPACK_SEQUENCE so we do it manually. Here we have certain + # assumptions about the structure of the backward trace. + assert self.trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL + assert self.trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" + assert self.trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE + assert self.trace.bound_symbols[4].args[0].name == "C0" + new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( + swap_map, + skip_inputs=False, + skip_output=False, + skip_subsymbols=False, + ) + self.trace.bound_symbols = new_bsyms + + if self.apply_bucketing_bw_trace: + from thunder.distributed.transforms import FSDPCommBucketing + + self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace( + self.trace) + + # Not called in the constructor for bw traces + dce(self.trace) + + match_optimizer_algorithm() - def build_placement_options_fusion_regions(self, increment_factor: int = 1): + # For each fusion executor in the input list, find the best trace dispatching for each executor + def build_placement_options_best_fuser(self, increment_factor: int = 1): from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols from thunder.core.rematerialization import rematerialize_forward_and_backward @@ -728,7 +649,7 @@ def sequence_hash(s: Sequence) -> str: return name # TODO (matteochen): Benchmark the optimal executor and call this optimal - def get_optimal_executor(bsym: BoundSymbol): + def get_first_available_executor(bsym: BoundSymbol): for ex in self.executors: if isinstance(ex, FusionExecutor): continue @@ -870,7 +791,7 @@ def match_bsym_output(bsym_in: BoundSymbol, time_dict: dict, mem_dict: dict, ex_ current_bsym = group[0] self.log(f"--> Single group: {current_bsym.sym.name}") name = current_bsym.sym.name - optimal_ex = get_optimal_executor(current_bsym) + optimal_ex = get_first_available_executor(current_bsym) if name == "return": dict_time_strat["return"] = optimal_ex dict_mem_strat["return"] = optimal_ex @@ -882,10 +803,10 @@ def match_bsym_output(bsym_in: BoundSymbol, time_dict: dict, mem_dict: dict, ex_ continue # Inside groups we should have alwasy tensors as out - best_res_time = self.Result() - best_res_mem = self.Result() - worst_res_time = self.Result() - worst_res_mem = self.Result() + best_res_time = self.BenchmarkResult() + best_res_mem = self.BenchmarkResult() + worst_res_time = self.BenchmarkResult() + worst_res_mem = self.BenchmarkResult() # Only for visual worst_res_mem.measure = 0 worst_res_time.measure = 0 @@ -907,10 +828,10 @@ def measure_and_update_result(): nonlocal worst_res_mem trc, keys, placements = get_placed_trace( dict_time_strat, increasing_symbols) - if self.trace_type == TraceType.BW and self.cached_fw_trace is not None: + if self.trace_type == TraceType.BW and self.active_fw_trace is not None: _, trc = rematerialize_forward_and_backward( - self.cached_fw_trace, trc) - cost, mem, out = benchmark_trace(trc, iters=3) + self.active_fw_trace, trc) + cost, mem, out = benchmark_trace(trc, self.benchmark_iters) del out self.log( f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}") @@ -964,7 +885,7 @@ def measure_and_update_result(): group[j], dict_time_strat, dict_mem_strat, ex) for k in range(start_idx + i + 1, len(group), increment_factor): match_bsym_output( - group[k], dict_time_strat, dict_mem_strat, get_optimal_executor( + group[k], dict_time_strat, dict_mem_strat, get_first_available_executor( group[k]) ) # Benchmark @@ -1062,7 +983,7 @@ def measure_and_update_result(): else: trc = self.place_optimizers(self.trace, executors_mem) _, trc = rematerialize_forward_and_backward( - self.cached_fw_trace, trc) + self.active_fw_trace, trc) c, m, o = benchmark_trace(trc) del o self.log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc}") @@ -1070,7 +991,7 @@ def measure_and_update_result(): ex.name: trc}) trc = self.place_optimizers(self.trace, executors_time) _, trc = rematerialize_forward_and_backward( - self.cached_fw_trace, trc) + self.active_fw_trace, trc) c, m, o = benchmark_trace(trc) del o self.log(f"Debug TIME, time = {c} ms:\n{trc}") @@ -1083,205 +1004,197 @@ def measure_and_update_result(): self.executor_placement_options.placement_options_mem.append( executors_mem) - def get_optimal_fw_traces_time_and_mem(self) -> tuple[TraceCtx, TraceCtx]: - if self.fw_trace_candidates.best_time is None or self.fw_trace_candidates.best_mem is None: + def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + if not self.cached_fw_traces: raise AssertionError("Failed to obtain optimal fw traces") - return self.fw_trace_candidates.best_time, self.fw_trace_candidates.best_mem + return [getattr(candidate, field) for candidate in self.cached_fw_traces.values() for field in ['best_time', 'best_mem']] def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: # This is agnostic from the optimization strat as results are both floats min_value: float = float("inf") ans: FinalOutputCandidates | None = None + self.log(f'Computing the best pair option (tot options = {len(self.out)})') for pair in self.out: if pair.tot_cost < min_value: self.log(f"New best pair:\n{pair}") min_value = pair.tot_cost ans = pair + if ans is None: + raise AssertionError('Best pair not found') return ans.fw, ans.bw def bsym_assigned(self, bsym: BoundSymbol) -> bool: return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) def benchmark_traces(self): - tm = self.Result() - mem = self.Result() self.debug_msg += "Traces benchmarks:\n\n" - source_mem = None - source_time = None - match self.optimization_algorithm: - case OptimizationAlgorithm.BEST_FUSER: - # TODO: handle requests for no remat when introduced in thunder (split_fw_bw) - source_mem = self.fusion_strat_helper.optimized_traces_mem_benchmark_only - source_time = self.fusion_strat_helper.optimized_traces_time_benchmark_only - case OptimizationAlgorithm.GREEDY: - # TODO (matteochen) - raise AssertionError("Not supported") - case OptimizationAlgorithm.EXAUSTIVE: - raise AssertionError("Not supported") - - # Find best trace for runtime - for i, trace_info in enumerate(source_time): - # Unpack the dict - label = None - trace = None - for k, v in trace_info.items(): - label = k - trace = v - - trace_time, trace_mem, res = benchmark_trace(trace, iters=10) - del res - self.debug_msg += ( - f"Trace name = [{label}] - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" - ) + # We cached every optimized fw traces as they might impact differently on the bw trace + # Number of fw traces to cached are: #fusion_executors * 2 + def fw_benchmark(): + match self.optimization_algorithm: + case OptimizationAlgorithm.BEST_FUSER: + # The optimizator builds the results in order following the self.fusion_executors list order + for pair_time, pair_mem in zip(self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem): + # pair is a dict + trc_time = list(pair_time.values())[0] + trc_mem = list(pair_mem.values())[0] + label = list(pair_time.keys())[0] + # TODO (matteochen): remove the benchmark here as will done later on the bw pass + c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) + self.log( + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}') + self.debug_msg += ( + f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" + ) + c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) + self.log( + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}') + self.debug_msg += ( + f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" + ) + + self.cached_fw_traces[label] = TraceCandidates(best_time = trc_time, best_mem = trc_mem) + + + def bw_benchmark(): + time_result = self.BenchmarkResult() + memory_result = self.BenchmarkResult() + + # Find best trace for runtime + for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_time_benchmark_only): + # Unpack the dict + label = list(pair.keys())[0] + trace = list(pair.values())[0] + trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) + self.debug_msg += ( + f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" + ) + self.log( + f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' + ) + if trace_time < time_result.tm: + time_result.tm = trace_time + time_result.mem = trace_mem + time_result.trace = trace + time_result.label = label + time_result.index = i + + # Find best trace for memory + for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_mem_benchmark_only): + # Unpack the dict + label = list(pair.keys())[0] + trace = list(pair.values())[0] + + trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) + del res + self.debug_msg += ( + f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" + ) + self.log( + f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' + ) + if trace_mem < memory_result.mem: + memory_result.tm = trace_time + memory_result.mem = trace_mem + memory_result.trace = trace + memory_result.label = label + memory_result.index = i + self.log( - f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' - ) - if trace_time < tm.tm: - tm.tm = trace_time - tm.mem = trace_mem - tm.trace = trace - tm.label = label - tm.index = i - - # Find best trace for memory - for i, trace_info in enumerate(source_mem): - # Unpack the dict - label = None - trace = None - for k, v in trace_info.items(): - label = k - trace = v - - trace_time, trace_mem, res = benchmark_trace(trace, iters=10) - del res - self.debug_msg += ( - f"Trace name = [{label}] - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" - ) + f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.tm} ms)":\n{time_result.trace}') self.log( - f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' - ) - if trace_mem < mem.mem: - mem.tm = trace_time - mem.mem = trace_mem - mem.trace = trace - mem.label = label - mem.index = i - - self.log( - f'Benchmark end: Best trace time "{tm.label} (time = {tm.tm} ms)":\n{tm.trace}') - self.log( - f'Benchmark end: Best trace mem "{mem.label} (mem = {mem.mem / (2 ** 30)} GB)":\n{mem.trace}') - - # TODO (matteochen): remove this - self.log("Strat comparison") - c, m, o = benchmark_trace(tm.trace) - del o - self.log(f"best time: {c} ms, {m/(2**30)} GB") - c, m, o = benchmark_trace(mem.trace) - del o - self.log(f"best mem: {c} ms, {m/(2**30)} GB") - - match self.optimization_algorithm: - case OptimizationAlgorithm.GREEDY: - raise AssertionError("Not implemented") - case OptimizationAlgorithm.BEST_FUSER: - # Here we have to recover the traces without the pass through remat in order to be compliant - # with thunder flow as we might have request for no remat - - d = self.fusion_strat_helper.optimized_traces_time[tm.index] - t = None - # Unpack dict - for _, v in d.items(): - t = v - if t is None: - raise AssertionError("None trace") - - match self.trace_type: - case TraceType.FW: - self.fw_trace_candidates.attach_best_time_candidate(t) - case TraceType.BW: - self.bw_trace_candidates.attach_best_time_candidate(t) - - d = self.fusion_strat_helper.optimized_traces_mem[mem.index] - t = None - # Unpack dict - for _, v in d.items(): - t = v - if t is None: - raise AssertionError("None trace") - - match self.trace_type: - case TraceType.FW: - self.fw_trace_candidates.attach_best_mem_candidate(t) - case TraceType.BW: - self.bw_trace_candidates.attach_best_mem_candidate(t) + f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.mem / (2 ** 30)} GB)":\n{memory_result.trace}') + + # TODO (matteochen): remove this + # self.log(f"Strat comparison: {self.trace_type}") + # c, m, o = benchmark_trace(tm.trace) + # del o + # self.log(f"best time: {c} ms, {m/(2**30)} GB") + # c, m, o = benchmark_trace(mem.trace) + # del o + # self.log(f"best mem: {c} ms, {m/(2**30)} GB") + + # Here we have to recover the traces without the pass through remat in order to be compliant + # with thunder flow as we might have request for no remat + match self.optimization_algorithm: + case OptimizationAlgorithm.BEST_FUSER: + # Unpack dict + trc = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0] + self.bw_trace_candidates.attach_best_time_candidate(trc) - match self.trace_type: - case TraceType.FW: - self.log(self.fw_trace_candidates.__repr__()) - case TraceType.BW: - self.log(self.bw_trace_candidates.__repr__()) + # Unpack dict + trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0] + self.bw_trace_candidates.attach_best_mem_candidate(trc) + + self.log(self.bw_trace_candidates.__repr__()) - # Now, finally build the pair fw and bw traces for the requested strat - if self.trace_type == TraceType.BW: + # Now, finally build the pair fw and bw traces for the requested strat + # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller forward_time, forward_memory, _ = benchmark_trace( - self.cached_fw_trace, iters=10) + self.active_fw_trace, self.benchmark_iters) match self.optimizer_type: case OptimizerType.RUNTIME: # Used the computed benchmark from above - if tm.tm < mem.tm: + if time_result.tm < memory_result.tm: self.log( - f"out candidate times: (fw){forward_time} ms, (bw){tm.tm} ms") + f"out candidate times: (fw){forward_time} ms, (bw){time_result.tm} ms") self.out.append( FinalOutputCandidates( - fw=self.cached_fw_trace, + fw=self.active_fw_trace, bw=self.bw_trace_candidates.best_time, - cost=forward_time + tm.tm, + cost=forward_time + time_result.tm, ) ) else: self.log( - f"out candidate times: (fw){forward_time} ms, (bw){mem.tm} ms") + f"out candidate times: (fw){forward_time} ms, (bw){memory_result.tm} ms") self.out.append( FinalOutputCandidates( - fw=self.cached_fw_trace, + fw=self.active_fw_trace, bw=self.bw_trace_candidates.best_mem, - cost=forward_time + mem.tm, + cost=forward_time + memory_result.tm, ) ) case OptimizerType.MEMORY: # Used the computed benchmark from above - if tm.mem < mem.mem: + if time_result.mem < memory_result.mem: self.log( - f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){tm.mem} GB") + f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){time_result.mem} GB") self.out.append( FinalOutputCandidates( - fw=self.cached_fw_trace, + fw=self.active_fw_trace, bw=self.bw_trace_candidates.best_time, - cost=forward_memory + tm.mem, + cost=forward_memory + time_result.mem, ) ) else: self.log( - f"out candidate mem: (fw){forward_memory} GB, (bw){mem.mem} GB") + f"out candidate mem: (fw){forward_memory} GB, (bw){memory_result.mem} GB") self.out.append( FinalOutputCandidates( - fw=self.cached_fw_trace, + fw=self.active_fw_trace, bw=self.bw_trace_candidates.best_mem, - cost=forward_memory + mem.mem, + cost=forward_memory + memory_result.mem, ) ) + match self.trace_type: + case TraceType.FW: + fw_benchmark() + case TraceType.BW: + bw_benchmark() + if self.produce_log: import time - timestamp: str = str(time.time()) with open(f"{timestamp}-{self.log_file_name}", "w") as file: file.write(self.debug_msg) file.close() + self.debug_msg = "" + def return_not_used_vars(trace_in: TraceCtx) -> list[TensorProxy]: def is_in_sequence(seq: Sequence[Any], t: TensorProxy): @@ -1496,7 +1409,6 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: else: raise AssertionError("Unexpexcted args type") - # Always benchmark trace after a deletion last used pass as the final trace out will passed under this stage if apply_del_last_used: trace = del_last_used(trace) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 3427d9724f..7977eeb06c 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -172,13 +172,11 @@ def autotune_transform_for_execution( # Assign the trace provenance match trace_type: case TraceType.FW: - fw_extrace_time, fw_extrace_mem = optimizer_context.get_optimal_fw_traces_time_and_mem() - fw_extrace_time.set_provenance( - TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") - ) - fw_extrace_mem.set_provenance( - TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") - ) + fw_traces = optimizer_context.get_optimal_fw_traces() + for trc in fw_traces: + trc.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) return None case TraceType.BW: bw_extrace.set_provenance( From 5dedbfffb4f7a5a948b1d2b2fed4926eced6a4c8 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 1 Aug 2024 11:57:40 +0300 Subject: [PATCH 024/171] Before `transform_for_execution` executors placement autotune / nvsight benchmark interface / nanoGPT model test (#5) --- examples/dev/LLaMAMLP.py | 41 +- examples/dev/MLP.py | 36 +- examples/dev/csa.py | 153 ------- examples/dev/litGPT.py | 76 ++-- examples/dev/nanogpt-block.py | 140 +++++++ examples/dev/nanogpt.py | 162 ++++++++ examples/dev/sdpa.py | 51 ++- examples/dev/simple.py | 124 +----- thunder/backend_optimizer/optimizer.py | 535 ++++++++++-------------- thunder/benchmarks/utils.py | 136 +++++++ thunder/executors/torch_autograd.py | 541 ++++++++++++++++--------- 11 files changed, 1129 insertions(+), 866 deletions(-) delete mode 100644 examples/dev/csa.py create mode 100644 examples/dev/nanogpt-block.py create mode 100644 examples/dev/nanogpt.py create mode 100644 thunder/benchmarks/utils.py diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 28d26e867a..db208eb108 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -1,5 +1,6 @@ import torch import thunder +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark class LLaMAMLP(torch.nn.Module): def __init__(self, n_embd, intermediate_size) -> None: @@ -14,7 +15,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) with torch.device('cuda'): - from thunder.backend_optimizer.optimizer import benchmark_trace # See changes from mult = 1 to mult = 4 mult = 1 a = 4096 * mult @@ -26,30 +26,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: jmodel_def = thunder.jit(model) jmodel_auto = thunder.jit(model, autotune_type='memory', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) - y = model(x) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + + print('Results with thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + thunder_fw_bw_benchmark(traces, labels, 50, nvsight = False) + + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] + print('Results with torch fw bw benchmark:') + torch_fw_bw_benchmark(callables, labels, inputs, 50) - print('########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_fw', iters=10) - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_fw', iters=10) - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_def_bw', iters=10) - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='LLaMAMLP_auto_bw', iters=10) - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - del o - - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') - - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + for t in traces: + print(f'{t}\n#####################################') diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py index 4c1fa488ae..9dfa7510a6 100644 --- a/examples/dev/MLP.py +++ b/examples/dev/MLP.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import thunder -from thunder.backend_optimizer.optimizer import benchmark_trace +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight class ModelConfig: def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): @@ -36,29 +36,23 @@ def forward(self, x): jmodel_def = thunder.jit(model) # This model fails under some circumstances after passed the placed traced under the rematelizer - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'torch', 'python']) - y = model(x) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) - print('########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_fw', iters=10) - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_fw', iters=10) - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_def_bw', iters=10) - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='MLP_auto_bw', iters=10) - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + print('Results with thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + thunder_fw_bw_benchmark(traces, labels, 50, nvsight = False) - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] + print('Results with torch fw bw benchmark:') + torch_fw_bw_benchmark(callables, labels, inputs, 50) - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + # for t in traces: + # print(t) + # print('##########################') diff --git a/examples/dev/csa.py b/examples/dev/csa.py deleted file mode 100644 index c011e5d48b..0000000000 --- a/examples/dev/csa.py +++ /dev/null @@ -1,153 +0,0 @@ -import torch -import torch.nn as nn -import thunder -from thunder.backend_optimizer.optimizer import benchmark_trace - -# import torch._dynamo -# torch._dynamo.config.suppress_errors = True - -class CausalSelfAttention(nn.Module): - - def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0): - super().__init__() - assert embed_dimension % num_heads == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias) - # output projection - self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias) - # regularization - self.dropout = dropout - self.resid_dropout = nn.Dropout(dropout) - self.num_heads = num_heads - self.embed_dimension = embed_dimension - # Perform causal masking - self.is_causal = is_causal - - def forward(self, x): - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - query_projected = self.c_attn(x) - - batch_size = query_projected.size(0) - embed_dim = query_projected.size(2) - head_dim = embed_dim // (self.num_heads * 3) - - query, key, value = query_projected.chunk(3, -1) - query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2) - - if self.training: - dropout = self.dropout - is_causal = self.is_causal - else: - dropout = 0.0 - is_causal = False - - y = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal) - y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim) - - y = self.resid_dropout(self.c_proj(y)) - return y - -device = torch.device('cuda') -num_heads = 8 -heads_per_dim = 64 * 4 -embed_dimension = num_heads * heads_per_dim -dtype = torch.float32 -model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(device).to(dtype) -print(model) -batch_size = 16 -max_sequence_len = 1024 -x = torch.randn(batch_size, max_sequence_len, embed_dimension, dtype=dtype, requires_grad=True, device=device) - -jmodel_def = thunder.jit(model) -jmodel_auto = thunder.jit(model, autotune_type='runtime') - -warm_up_iters = 2 -iters = 10 -stream = torch.cuda.current_stream() - -y = model(x) -for _ in range(warm_up_iters): - yy = jmodel_def(x) - yyy = jmodel_auto(x) - torch.autograd.grad(yy, x, grad_outputs=torch.ones_like(y)) - torch.autograd.grad(yyy, x, grad_outputs=torch.ones_like(y)) - -print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) -print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - -# print('\n\n') - -# start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -# middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -# end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - -# for i in range(iters): -# torch.cuda.empty_cache() -# torch.cuda._sleep(1_000_000) -# start_events[i].record(stream) -# y = jmodel_auto(x) -# middle_events[i].record(stream) -# torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) -# end_events[i].record(stream) - -# torch.cuda.synchronize() -# fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] -# bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] -# tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] -# fw_time = sum(fw) -# bw_time = sum(bw) -# tot_time = sum(tot) -# print(f'Auto fw: {fw_time / iters}') -# print(f'Auto bw: {bw_time / iters}') -# print(f'Auto tot: {tot_time / iters}') -# print('\n') - -# start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -# middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -# end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - -# for i in range(iters): -# torch.cuda.empty_cache() -# torch.cuda._sleep(1_000_000) -# start_events[i].record(stream) -# y = jmodel_def(x) -# middle_events[i].record(stream) -# torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) -# end_events[i].record(stream) - -# torch.cuda.synchronize() -# fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] -# bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] -# tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] -# fw_time = sum(fw) -# bw_time = sum(bw) -# tot_time = sum(tot) -# print(f'Default fw: {fw_time / iters}') -# print(f'Default bw: {bw_time / iters}') -# print(f'Default tot: {tot_time / iters}') -# print('-------------------------------------------------------') - -c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], iters = 10, apply_del_last_used=False, snapshot=True, snapshot_name='def_fw') -print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') -del o -c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], iters=10, apply_del_last_used=False, snapshot=True, snapshot_name='auto_fw') -print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') -del o -c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], iters=10, apply_del_last_used=False, snapshot=True, snapshot_name='def_bw') -print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') -del o -c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], iters=10, apply_del_last_used=False, snapshot=True, snapshot_name='auto_bw') -print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') -del o - -print('\n\n\n\n\n\n') -print(f'{thunder.last_traces(jmodel_def)[-1]}') -print('###############################################################################') -print(f'{thunder.last_traces(jmodel_auto)[-1]}') - -print('\n\n') -print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') -print('###############################################################################') -print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index effb664c25..a626201e6f 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -1,49 +1,55 @@ from litgpt import GPT +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark from thunder.tests.litgpt_model import Config import thunder import torch -from thunder.backend_optimizer.optimizer import benchmark_trace class Test: def __init__(self, layers: int, autotune_type: str) -> None: self.layers = layers self.autotune_type = autotune_type -layers = [Test(1, 'runtime')] +layers = [Test(8, 'runtime')]#, Test(8, 'runtime'), Test(16, 'runtime')] model_name = 'Llama-2-7b-hf' for test in layers: - print('Layers:', test.layers) - cfg = Config.from_name(model_name) - cfg.n_layer = test.layers - torch.set_default_dtype(torch.bfloat16) - with torch.device('cuda'): - model = GPT(cfg) - x = torch.randint(1, model.config.vocab_size, (1, 512)) - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) - - print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) - - print('Results ########################################') - c, m, _ = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, _ = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, _ = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - c, m, _ = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name=model_name) - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') - - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + try: + print('Layers:', test.layers) + cfg = Config.from_name(model_name) + cfg.n_layer = test.layers + torch.set_default_dtype(torch.bfloat16) + with torch.device('cuda'): + model = GPT(cfg) + x = torch.randint(1, model.config.vocab_size, (1, 512)) + + jmodel_def = thunder.jit(model) + # Torchcompile gives some troubles for now + jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'cudnn', 'torch', 'python']) + + print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + + + print('Results thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + thunder_fw_bw_benchmark(traces, labels, 50) + + print('\n\nResults torch fw bw benchmark:') + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] + torch_fw_bw_benchmark(callables, labels, inputs, 50) + + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') + + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + except Exception as e: + print(f'Test failed:\n{e}') diff --git a/examples/dev/nanogpt-block.py b/examples/dev/nanogpt-block.py new file mode 100644 index 0000000000..59ca2fe21e --- /dev/null +++ b/examples/dev/nanogpt-block.py @@ -0,0 +1,140 @@ +import math +import torch +import torch.nn as nn +from torch.nn import functional as F +import thunder +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark + +# torch.set_default_dtype(torch.bfloat16) + +class LayerNorm(nn.Module): + """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ + + def __init__(self, ndim, bias): + super().__init__() + self.weight = nn.Parameter(torch.ones(ndim)) + self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None + + def forward(self, input): + return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) + +class CausalSelfAttention(nn.Module): + + def __init__(self, config): + super().__init__() + assert config.n_embd % config.n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) + # regularization + self.attn_dropout = nn.Dropout(config.dropout) + self.resid_dropout = nn.Dropout(config.dropout) + self.n_head = config.n_head + self.n_embd = config.n_embd + self.dropout = config.dropout + # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 + self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + if not self.flash: + print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") + # causal mask to ensure that attention is only applied to the left in the input sequence + self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) + .view(1, 1, config.block_size, config.block_size)) + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + if self.flash: + # efficient attention using Flash Attention CUDA kernels + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + else: + # manual implementation of attention + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + +class MLP(nn.Module): + + def __init__(self, config): + super().__init__() + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.dropout = nn.Dropout(config.dropout) + + def forward(self, x): + x = self.c_fc(x) + x = self.gelu(x) + x = self.c_proj(x) + x = self.dropout(x) + return x + +class GPTConfig: + block_size: int = 1024 + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: int = 12 + n_head: int = 12 + n_embd: int = 3072 + dropout: float = 0.0 + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + +class Block(nn.Module): + + def __init__(self, config): + super().__init__() + self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) + self.attn = CausalSelfAttention(config) + self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) + self.mlp = MLP(config) + + def forward(self, x): + x = x + self.attn(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + +with torch.device('cuda'): + config = GPTConfig() + model = Block(config) + x = torch.randn((16, 1024, 3072), dtype=torch.float32) + + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + + print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + + + print('Results thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + thunder_fw_bw_benchmark(traces, labels, 50) + + print('\n\nResults torch fw bw benchmark:') + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] + torch_fw_bw_benchmark(callables, labels, inputs, 50) + + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') + + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py new file mode 100644 index 0000000000..23318334c6 --- /dev/null +++ b/examples/dev/nanogpt.py @@ -0,0 +1,162 @@ +import torch +import thunder +from thunder.benchmarks.utils import thunder_fw_bw_benchmark +from thunder.tests.nanogpt_model import GPTConfig, GPT + +warm_up_iters = 50 + +def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: + + for m, input, label in zip(models, inputs, labels): + # Warm up + for _ in range(warm_up_iters): + _, loss = m(input) + loss.backward() + + torch.cuda.synchronize() + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + + start_events[i].record(stream) + out = m(input) + end_events[i].record(stream) + + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f'{label} tot fw time: {tot_time} ms') + print(f'{label} max fw allocated memory: {max_allocated_bytes / (2**30)} GB') + + torch.cuda.synchronize() + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + + _, loss = m(input) + loss.backward() + start_events[i].record(stream) + loss.backward() + end_events[i].record(stream) + + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f'{label} tot bw time: {tot_time} ms') + print(f'{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB') + +def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: + + for m, input, label in zip(models, inputs, labels): + # Warm up + for _ in range(warm_up_iters): + _, loss = m(input) + loss.backward() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + torch.cuda.synchronize() + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + + start_events[i].record(stream) + _, loss = m(input) + loss.backward() + end_events[i].record(stream) + + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f'{label} tot time: {tot_time} ms') + print(f'{label} max allocated memory: {max_allocated_bytes / (2**30)} GB') + +# ----------------------------------------------------------------------------- +batch_size = 12 +block_size = 1024 +bias = False +seed = 1337 +device = 'cuda' +# dtype = 'float16' +dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' +# ----------------------------------------------------------------------------- +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn +device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast +ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] + +x = torch.randint(50304, (batch_size, block_size), device=device) +y = torch.randint(50304, (batch_size, block_size), device=device) +get_batch = lambda split: (x, y) + +# model init +gptconf = GPTConfig( + block_size = block_size, # how far back does the model look? i.e. context size + n_layer = 1, n_head = 12, n_embd = 768, # size of the model + dropout = 0, # for determinism + bias = bias, +) +model = GPT(gptconf) +model.to(device) + +jmodel_def = thunder.jit(model) +jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + +X, Y = get_batch('train') +# Run compilation +jmodel_def(x, y) +jmodel_auto(x, y) + +print('Results thunder benchmark:') +traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] +labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] +thunder_fw_bw_benchmark(traces, labels, 50) + +print('\n\nResults torch fw bw benchmark:') +callables = [jmodel_def, jmodel_auto] +labels = ['def', 'auto'] +inputs = [x, x] +torch_fw_bw_benchmark(callables, labels, inputs, 50) +print('\n\nResults torch tot benchmark:') +torch_total_benchmark(callables, labels, inputs, 50) + +print('\n\n\n\n\n\n') +print(f'{thunder.last_traces(jmodel_def)[-1]}') +print('###############################################################################') +print(f'{thunder.last_traces(jmodel_auto)[-1]}') + +print('\n\n') +print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') +print('###############################################################################') +print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index a4d1d35328..d307cae578 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -1,28 +1,49 @@ import torch import thunder +from thunder.backend_optimizer.optimizer import benchmark_trace -class Module(torch.nn.Module): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) +class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() def forward(self, query, key, value): - query = query + query - key = key * key a = torch.nn.functional.scaled_dot_product_attention(query, key, value) return a - with torch.device('cuda'): - module = Module() - j_module = thunder.jit(module) - - query = torch.rand(32, 8, 128, 64, dtype=torch.float16) - key = torch.rand(32, 8, 128, 64, dtype=torch.float16) - value = torch.rand(32, 8, 128, 64, dtype=torch.float16) - - ans = j_module(query, key, value) + model = Model() + + jmodel_def = thunder.jit(model) + # Order does not matter anymore + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + + q = torch.rand(32, 8, 128, 64*16, dtype=torch.float32, requires_grad=True) + k = torch.rand(32, 8, 128, 64*16, dtype=torch.float32, requires_grad=True) + v = torch.rand(32, 8, 128, 64*16, dtype=torch.float32, requires_grad=True) + + print('deviation def:', (jmodel_def(q, k, v) - model(q, k, v)).abs().max().item()) + print('deviation auto:', (jmodel_auto(q, k, v) - model(q, k, v)).abs().max().item()) + + print('########################################') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_fw', iters=10) + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_fw', iters=10) + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_bw', iters=10) + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_bw', iters=10) + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') + + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - print(thunder.last_traces(j_module)[-1]) diff --git a/examples/dev/simple.py b/examples/dev/simple.py index a3128ffda2..b995a83102 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -1,6 +1,6 @@ import torch import thunder -from thunder.backend_optimizer.optimizer import benchmark_trace +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: @@ -12,120 +12,28 @@ def forward(self, x: torch.Tensor): a = x + x b: torch.Tensor = self.linear(a) c = b * b - d = c + c - return self.silu(d) + return self.silu(c) with torch.device('cuda'): - multiplier = 1000 + multiplier = 100 in_features = 20 * multiplier out_features = 30 * multiplier model = Module(in_features, out_features) x = torch.randn(128, in_features, requires_grad=True) - jmodel_def = thunder.jit(model, autotune_executors=False) - jmodel_auto = thunder.jit(model, autotune_executors=True) - stream = torch.cuda.current_stream() + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'torchcompile', 'cudnn', 'sdpa', 'torch', 'python']) - warm_up_iters = 2 - iters = 10 - for _ in range(warm_up_iters): - y = jmodel_auto(x) - yy = jmodel_def(x) - grad_outputs = torch.ones_like(y) - torch.autograd.grad(y, x, grad_outputs=grad_outputs) - torch.autograd.grad(yy, x, grad_outputs=grad_outputs) + y = jmodel_def(x) + y = jmodel_auto(x) - print('\n\n') + print('Results thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + thunder_fw_bw_benchmark(traces, labels, 50) - for i in range(1): - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - y = jmodel_auto(x) - middle_events[i].record(stream) - torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) - end_events[i].record(stream) - - torch.cuda.synchronize() - fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] - bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - fw_time = sum(fw) - bw_time = sum(bw) - tot_time = sum(tot) - print(f'Auto fw: {fw_time / iters}') - print(f'Auto bw: {bw_time / iters}') - print(f'Auto tot: {tot_time / iters}') - print('\n') - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - middle_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - y = jmodel_def(x) - middle_events[i].record(stream) - torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y)) - end_events[i].record(stream) - - torch.cuda.synchronize() - fw = [s.elapsed_time(e) for s, e in zip(start_events, middle_events)] - bw = [s.elapsed_time(e) for s, e in zip(middle_events, end_events)] - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - fw_time = sum(fw) - bw_time = sum(bw) - tot_time = sum(tot) - print(f'Default fw: {fw_time / iters}') - print(f'Default bw: {bw_time / iters}') - print(f'Default tot: {tot_time / iters}') - print('-------------------------------------------------------') - - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False) - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False) - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False) - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - del o - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False) - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - del o - # print('\n\n\n\n\n\n') - # print(f'{thunder.last_traces(jmodel_def)[-1]}') - # print('###############################################################################') - # print(f'{thunder.last_traces(jmodel_auto)[-1]}') - - # print('\n\n') - # print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - # print('###############################################################################') - # print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - - from torch.profiler import profile, record_function, ProfilerActivity - with profile(activities=[ - ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("def"): - y = jmodel_def(x) - grad_outputs = torch.ones_like(y) - torch.autograd.grad(y, x, grad_outputs=grad_outputs) - - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - with profile(activities=[ - ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof: - with record_function("auto"): - y = jmodel_auto(x) - grad_outputs = torch.ones_like(y) - torch.autograd.grad(y, x, grad_outputs=grad_outputs) - - print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] + print('Results torch benchmark:') + torch_fw_bw_benchmark(callables, labels, inputs, 50) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 76afecfbcf..48a673ddfc 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -8,15 +8,23 @@ from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, variableify, Variable from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx -from thunder.executors.data_dependent_partition import Graph, Node +from thunder.executors.data_dependent_partition import Node from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Any, Hashable import thunder import thunder.core.transforms as transforms -import concurrent.futures import torch +# Currently this manages both time and memory +class BenchmarkResult: + def __init__(self) -> None: + self.tm: float = float("inf") + self.mem: float = float("inf") + self.trace: TraceCtx | None = None + self.label: str | Hashable = "" + self.index = -1 + class OptimizerType(Enum): MEMORY = 0 @@ -29,8 +37,7 @@ class TraceType(Enum): class OptimizationAlgorithm(Enum): - GREEDY = 0 - BEST_FUSER = 1 + BEST_FUSER = 0 class OptimizerNode: @@ -76,7 +83,6 @@ def __repr__(self) -> str: # Benchmark only traces will contain traces after the rematerialization call with fw and bw calls, reproducing what will be the real traces after the autotune pass # Non benchmark traces will contain traces after the placement (default) with no call to remat # We have duplciated those in order to maintain thunder compilation flow as the output from the autotuner will be the traces with no pass through rematerialization -# TODO (matteochen): currently the GREEDY strat is using this data structure, fix this class FusionStratHelper: def __init__(self) -> None: self.supported_executors: set = set(["nvfuser", "torchcompile"]) @@ -94,11 +100,22 @@ def __init__(self) -> None: self.placement_options_time: list[list[Executor]] = [] -class BackendOptimizer: - def log(self, what: str): +class LogLevel(Enum): + DEBUG = 0 + INFO = 1 + + +log_level: LogLevel = LogLevel.INFO + + +def log(what: str, level: LogLevel): + if log_level == LogLevel.DEBUG or log_level == level: print( f"================================================================================ Autotune: {what}") + +class BackendOptimizer: + def __init__( self, *, @@ -130,21 +147,19 @@ def __init__( self.bw_trace_candidates: TraceCandidates = TraceCandidates() self.out: list[FinalOutputCandidates] = [] - # Strat greedy - self.computation_graph: Graph - # Strat fusion self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() self.executor_placement_options: ExecutorPlacementOptions = ExecutorPlacementOptions() self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace - self.benchmark_iters = 10 + self.benchmark_iters = 5 - self.log("Executors:") + log("Executors:", level=LogLevel.INFO) for e in self.executors: - self.log( - f"{e.name} -> is operator = {isinstance(e, OperatorExecutor)}, is fusion = {isinstance(e, FusionExecutor)}" + log( + f"{e.name} -> is operator = {isinstance(e, OperatorExecutor)}, is fusion = {isinstance(e, FusionExecutor)}", + level=LogLevel.INFO, ) class SearchNode: @@ -152,15 +167,6 @@ def __init__(self, symbol: BoundSymbolInterface, idx: int) -> None: self.symbol = symbol self.idx = idx - # Currently this manages both time and memory - class BenchmarkResult: - def __init__(self) -> None: - self.tm: float = float("inf") - self.mem: float = float("inf") - self.trace: TraceCtx | None = None - self.label: str | Hashable = "" - self.index = -1 - def attach_cached_fw_traces(self, cached_fw_traces: TraceCandidates, executor_name: str) -> None: self.cached_fw_traces[executor_name] = cached_fw_traces @@ -174,161 +180,15 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): match self.trace_type: case TraceType.FW: - self.log( - f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") + log( + f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO) # TODO (matteochen): support bw trace optimization even though with no fw traces cached case TraceType.BW: if not self.cached_fw_traces: raise AssertionError( "Can not optimize backward traces before forward traces") - self.log( - f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") - - def build_placement_options_greedy(self): - import sys - - old_max_recursion = sys.getrecursionlimit() - sys.setrecursionlimit(2000) - - # Last index inclusive - def benchmark_partial_trace( - trace_in: TraceCtx, last_idx: int, configuration: list[Executor] - ) -> tuple[float, TraceCtx]: - def safe_update_dict(d: dict, key_one, key_two, value): - if key_one not in d: - d[key_one] = {key_two: value} - else: - d[key_one][key_two] = value - - # Retrive all output tensors from each subregion - tensors = [] - for i in range(last_idx + 1): - if not isinstance(trace_in.bound_symbols[i], BoundSymbol): - raise AssertionError( - "Expected BoundSymbol but received BoundSymbolInterface") - s = trace_in.bound_symbols[i] - # For each bsym region we expect to output a Tensor - tensors.append(s.output) - - forced_return_bsym = trace_in.bound_symbols[-1].from_bsym( - args=tensors - ) # Should not be an Interface type at this point - - t = from_trace(trace_in) - # Cut the trace to the required depth - t.bound_symbols = list(trace_in.bound_symbols)[: last_idx + 1] - - t.bound_symbols.append(forced_return_bsym) - configuration.append( - Executor(name=self.empty_executor_hashable_placeholder) - ) # Empty executor for the forced_return - - # Place the assigned symbols - placed_t = self.place_optimizers(t, configuration) - - cost, mem, answer = benchmark_trace(placed_t, self.benchmark_iters) - del answer - self.log( - f"Executing partial trace for incremental benchmark:\n{placed_t}") - self.log(f"Symbol under test = {t.bound_symbols[-2].sym.name}") - self.log(f"Assigned executor = {configuration[-2].name}") - self.log(f"Time = {cost} ms") - # TODO (matteochen): log this to file - safe_update_dict(self.partial_costs, whoami, t, cost) - return cost, placed_t - - # We assign an internal id to each symbol based on its idx inside the bound_symbols list - def search(node: self.SearchNode, configuration: list[Executor]): - def continue_search(): - if node.idx + 1 < max_len: - new_idx: int = node.idx + 1 - new_symbol: BoundSymbolInterface = bound_symbols[new_idx] - search(self.SearchNode(new_symbol, new_idx), configuration) - else: - all_configurations.append(configuration) - - has_backend = False - min_cost = float("inf") - min_cost_ex = None - ex: Executor - # TODO (matteochen): do parallel for - for ex in self.executors: - cost = float("inf") - if not isinstance(node.symbol, BoundSymbol): - raise AssertionError( - "Receive a symbol which is not a BoundSymbol") - if isinstance(ex, OperatorExecutor) and ex.can_execute(node.symbol): - has_backend = True - - configuration.append(ex) - cost, _ = benchmark_partial_trace( - self.trace, node.idx, list(configuration)) - configuration.pop() - - if isinstance(ex, FusionExecutor) and ex.can_fuse(node.symbol): - has_backend = True - - configuration.append(ex) - cost, _ = benchmark_partial_trace( - self.trace, node.idx, list(configuration)) - configuration.pop() - - if cost < min_cost: - min_cost = cost - min_cost_ex = ex - - if not has_backend: - configuration.append(empty_executor) - continue_search() - else: - if min_cost_ex is None: - raise AssertionError( - "Unexpected min cost executor or trace: None") - self.log( - f"\nFor id: {node.idx} - {node.symbol.sym.name} -> best backend {min_cost_ex.name}\n") - configuration.append(min_cost_ex) - continue_search() - - bound_symbols: list[BoundSymbolInterface] = self.trace.bound_symbols - max_len = len(bound_symbols) - - all_configurations: list[list[Executor]] = [] - # Is the name reserved? - empty_executor = Executor( - name=self.empty_executor_hashable_placeholder) - - if len(bound_symbols) > 0: - search(self.SearchNode(bound_symbols[0], 0), []) - self.placement_options = all_configurations - - sys.setrecursionlimit(old_max_recursion) - - # This expects a trace after the placement call. - # Fusion operators as nvFuser can be slower on the single trace region but can be faster by combining more of them, - # try to fuse then and compare - def try_to_fuse_after_executors_placement(self, trace_in: TraceCtx) -> TraceCtx: - best_trace: TraceCtx = trace_in - best_time, best_mem, answer = benchmark_trace(best_trace, self.benchmark_iters) - del answer - trace_in_time = best_time - - for ex in self.fusion_executors: - self.log(f"Try to fuse executor {ex.name} with trace:\n{trace_in}") - extrace = ex.fusion_pass(trace_in) - self.log(f"Fused trace:\n{extrace}") - extrace_time, extrace_mem, answer = benchmark_trace( - extrace, self.benchmark_iters) - del answer - self.log(f"Fused trace time:{extrace_time} ms") - - if extrace_time < best_time: - best_time = extrace_time - best_trace = extrace - - self.log(f"Trace in (time = {trace_in_time } ms):\n{trace_in}") - self.log(f"Best fused trace (time = {best_time } ms):\n{best_trace}") - - return best_trace + log( + f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO) def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: from thunder.executors.passes import _transform_for_operator_executor_execution @@ -421,7 +281,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: raise AssertionError( "len(executor_list) != len(in_trace.bound_symbols)") - # self.log(f'Visit transf') + # log(f'Visit transf') # for n, e in zip(in_trace.bound_symbols, executor_list): # print(f'{n.sym.name} -> {e.name}') cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} @@ -487,9 +347,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return extrace - # TODO (matteochen): add config for exaustive search or incremental one def optimize(self, strat: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER): - from thunder.executors.passes import transform_for_execution from thunder.core.transform_common import replace_redundant_inputs from thunder.core.transform_common import dce @@ -523,42 +381,8 @@ def optmize_best_fuser(): self.benchmark_traces() - def optimize_greedy(): - # Reset helpers data structures - self.executor_placement_options = ExecutorPlacementOptions() - - # 1. This builds one option by default - self.build_placement_options_greedy() - - if len(self.placement_options) != 1: - raise AssertionError("Unexpected placement options size") - - option = self.placement_options[0] - trace_greedy = self.place_optimizers(self.trace, option) - # Append the unique trace - self.optimized_traces.append({"greedy": trace_greedy}) - - # 2. Try to fuse additional regions from the greedy result - # Attention, if all the fused traces perform worse that the greedy one, the greedy one is returned - # TODO (matteochen): ignore a duplicated trace - trace_greedy_fused = self.try_to_fuse_after_executors_placement( - trace_greedy) - self.optimized_traces.append({"fused_greedy": trace_greedy_fused}) - - # 3. Try the priority list approach - trace_priority = transform_for_execution( - self.trace, self.executors) - self.optimized_traces.append({"priority_list": trace_priority}) - - # There are no hidden placements hence do not call the visualizer - - # Run benchmarks - self.benchmark_traces() - def match_optimizer_algorithm(): match self.optimization_algorithm: - case OptimizationAlgorithm.GREEDY: - optimize_greedy() case OptimizationAlgorithm.BEST_FUSER: optmize_best_fuser() @@ -572,7 +396,7 @@ def match_optimizer_algorithm(): cached_self_trace.bound_symbols = list( self.trace.bound_symbols) for label, candidate in self.cached_fw_traces.items(): - self.log(f'Backward optimization with fw from {label}') + log(f'Backward optimization with fw from {label}', level=LogLevel.INFO) fw_traces = candidate.iterable() for trc in fw_traces: @@ -585,8 +409,8 @@ def match_optimizer_algorithm(): # Set the current active cached forward trace self.active_fw_trace = trc - self.log(f"Cached fw trace:\n{self.active_fw_trace}") - self.log(f"Input bw trace:\n{self.trace}") + log(f"Cached fw trace:\n{self.active_fw_trace}", level=LogLevel.DEBUG) + log(f"Input bw trace:\n{self.trace}", level=LogLevel.DEBUG) # Some of the optimization passes change proxies in the trace and # any change in the forward trace must be reflected in the backward @@ -626,6 +450,12 @@ def match_optimizer_algorithm(): match_optimizer_algorithm() + def can_executor_execute(self, ex: Executor, bsym: BoundSymbol) -> bool: + try: + return ex.can_execute(bsym) + except: + return False + # For each fusion executor in the input list, find the best trace dispatching for each executor def build_placement_options_best_fuser(self, increment_factor: int = 1): from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols @@ -641,6 +471,9 @@ def sequence_hash(s: Sequence) -> str: or isinstance(e, FloatProxy) ): name += e.name + # TODO (matteochen): investigate if this is suitable + elif isinstance(e ,int): + name += f'int{e}' elif e is None: name += "None" else: @@ -653,13 +486,13 @@ def get_first_available_executor(bsym: BoundSymbol): for ex in self.executors: if isinstance(ex, FusionExecutor): continue - if ex.can_execute(bsym): + if self.can_executor_execute(ex, bsym): return ex return Executor(name=self.empty_executor_hashable_placeholder) def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[BoundSymbol]): - # self.log(f'Input mapping len = {len(mapping)}:') - # self.log(f'Input bound_symbols len = {len(bound_symbols_in)}:') + # log(f'Input mapping len = {len(mapping)}:') + # log(f'Input bound_symbols len = {len(bound_symbols_in)}:') trc = from_trace(self.trace) trc.bound_symbols = list(bound_symbols_in) @@ -722,8 +555,8 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo raise AssertionError( f"Fusion operator not supported: {ex.name}") - self.log( - f"Searching best placement for fusion executor = {ex.name}") + log( + f"Searching best placement for fusion executor = {ex.name}", level=LogLevel.DEBUG) # TODO (matteochen): each executor has a custom should fuse function, can we make this prettier? def _should_fuse_nvfuser(a: Node, b: Node): @@ -747,18 +580,18 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - def match_bsym_output(bsym_in: BoundSymbol, time_dict: dict, mem_dict: dict, ex_in: Executor): + def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): if isinstance(bsym_in.output, Sequence): - time_dict[sequence_hash(bsym_in.output)] = ex_in - mem_dict[sequence_hash(bsym_in.output)] = ex_in + for d in dicts: + d[sequence_hash(bsym_in.output)] = ex_in elif ( isinstance(bsym_in.output, CollectionProxy) or isinstance(bsym_in.output, TensorProxy) or isinstance(bsym_in.output, IntegerProxy) or isinstance(bsym_in.output, FloatProxy) ): - time_dict[bsym_in.output.name] = ex_in - mem_dict[bsym_in.output.name] = ex_in + for d in dicts: + d[bsym_in.output.name] = ex_in else: raise AssertionError( f"Type not handled: {type(bsym_in.output)}") @@ -766,47 +599,89 @@ def match_bsym_output(bsym_in: BoundSymbol, time_dict: dict, mem_dict: dict, ex_ bound_symbol_groups = fuse_bound_symbols( self.trace, _should_fuse_nvfuser if ex.name == "nvfuser" else _should_fuse_torchcompile ) - self.log(f"Num of groups = {len(bound_symbol_groups)}") + log(f"Num of groups = {len(bound_symbol_groups)}", level=LogLevel.DEBUG) for id, group in enumerate(bound_symbol_groups): - self.log(f"Group id: {id}") + log(f"Group id: {id}", level=LogLevel.DEBUG) for sub in group: - self.log(f"{sub.sym.name} -> out: {sub.output}") - if len(group) > 0: - print("\n") + log(f"{sub.sym.name} -> out: {sub.output}", level=LogLevel.DEBUG) + # if len(group) > 0: + # print("\n") dict_time_strat: dict[str, Executor] = {} dict_mem_strat: dict[str, Executor] = {} increasing_symbols = [] for group_id, group in enumerate(bound_symbol_groups): - self.log(f"Group id: {group_id}") - self.log(f"group start = {group[0].sym.name}") - self.log(f"group end = {group[-1].sym.name}") + log(f"Group id: {group_id}", level=LogLevel.DEBUG) + log(f"group start = {group[0].sym.name}", level=LogLevel.DEBUG) + log(f"group end = {group[-1].sym.name}", level=LogLevel.DEBUG) if group[0].sym.name != "return": increasing_symbols += group - # Is not a fusion region, get the optimal executor + # Is not a fusion region, get the optimal executor (OperatorExecutor) if len(group) < 2: current_bsym = group[0] - self.log(f"--> Single group: {current_bsym.sym.name}") + log(f"--> Single group: {current_bsym.sym.name}", level=LogLevel.DEBUG) name = current_bsym.sym.name - optimal_ex = get_first_available_executor(current_bsym) + # Filter out all possible candidates for the current symbol + candidate_executors = [ex for ex in self.executors if self.can_executor_execute(ex, current_bsym) and not isinstance(ex, FusionExecutor)] + if name == "return": - dict_time_strat["return"] = optimal_ex - dict_mem_strat["return"] = optimal_ex + dict_time_strat["return"] = Executor(name=self.empty_executor_hashable_placeholder) + dict_mem_strat["return"] = Executor(name=self.empty_executor_hashable_placeholder) # Add the modified return statement at the end of the for loop break - # before was ex??? + + # Not executors available + if not candidate_executors: + match_bsym_output( + current_bsym, [dict_time_strat, dict_mem_strat], Executor(name=self.empty_executor_hashable_placeholder)) + continue + else: + log(f'Available executors for single region:\n{candidate_executors}', level=LogLevel.DEBUG) + + # Helpers + candidate_best_time = BenchmarkResult() + candidate_best_mem = BenchmarkResult() + # Search for best candidate + for i, candidate in enumerate(candidate_executors): + # Match the current candidate to benchmark partial trace + match_bsym_output( + current_bsym, [dict_time_strat, dict_mem_strat], candidate) + # Retrieve partial trace and benchmark, apply remat if possible + trc, _, _ = get_placed_trace( + dict_time_strat, increasing_symbols) + if self.trace_type == TraceType.BW and self.active_fw_trace is not None: + _, trc = rematerialize_forward_and_backward( + self.active_fw_trace, trc) + t, m, _ = benchmark_trace(trc, self.benchmark_iters) + # Update results + if t < candidate_best_time.tm: + candidate_best_time.tm = t + candidate_best_time.index = i + + if m < candidate_best_mem.mem: + candidate_best_mem.mem = m + candidate_best_mem.index = i + + if candidate_best_time.index == -1 or candidate_best_mem.index == -1: + raise AssertionError(f'Failed to get optimal single trace region candidate. Available candidates for {name}:\n{candidate_executors}') + + log(f'Best time OperatorExecutor for single {name}: {candidate_executors[candidate_best_time.index].name}', level=LogLevel.DEBUG) + log(f'Best mem OperatorExecutor for single {name}: {candidate_executors[candidate_best_mem.index].name}', level=LogLevel.DEBUG) + match_bsym_output( - current_bsym, dict_time_strat, dict_mem_strat, optimal_ex) + current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) + match_bsym_output( + current_bsym, [dict_mem_strat], candidate_executors[candidate_best_mem.index]) continue # Inside groups we should have alwasy tensors as out - best_res_time = self.BenchmarkResult() - best_res_mem = self.BenchmarkResult() - worst_res_time = self.BenchmarkResult() - worst_res_mem = self.BenchmarkResult() + best_res_time = BenchmarkResult() + best_res_mem = BenchmarkResult() + worst_res_time = BenchmarkResult() + worst_res_mem = BenchmarkResult() # Only for visual worst_res_mem.measure = 0 worst_res_time.measure = 0 @@ -833,8 +708,8 @@ def measure_and_update_result(): self.active_fw_trace, trc) cost, mem, out = benchmark_trace(trc, self.benchmark_iters) del out - self.log( - f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}") + log( + f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) if cost < best_res_time.tm or (cost == best_res_time.tm and mem < best_res_time.mem): best_res_time.tm = cost best_res_time.mem = mem @@ -861,12 +736,12 @@ def measure_and_update_result(): for idx in range(0, len(group)): if group[idx].sym.name == "embedding_backward": last_embedding_idx = idx - self.log(f"last embedding {last_embedding_idx}") + log(f"last embedding {last_embedding_idx}", level=LogLevel.DEBUG) if last_embedding_idx != -1: # Until last_embedding_idx (included) assigned to current fusion ex for i in range(0, last_embedding_idx + 1, 1): match_bsym_output( - group[i], dict_time_strat, dict_mem_strat, ex) + group[i], [dict_time_strat, dict_mem_strat], ex) if last_embedding_idx == len(group) - 1: # Benchmark @@ -875,17 +750,18 @@ def measure_and_update_result(): start_idx = last_embedding_idx + 1 n_missing_bsyms = len(group) - start_idx - for i in range(0, n_missing_bsyms): + for i in range(0, n_missing_bsyms, n_missing_bsyms-1 if self.trace_type == TraceType.BW else 1): + # for i in range(0, n_missing_bsyms): # From top to bottom (this will include the whole region) # -> First iteration is the one with fusion region with single element # -> Last iteration gives the complete fusion region for j in range(start_idx, start_idx + i + 1, increment_factor): match_bsym_output( - group[j], dict_time_strat, dict_mem_strat, ex) + group[j], [dict_time_strat, dict_mem_strat], ex) for k in range(start_idx + i + 1, len(group), increment_factor): match_bsym_output( - group[k], dict_time_strat, dict_mem_strat, get_first_available_executor( + group[k], [dict_time_strat, dict_mem_strat], get_first_available_executor( group[k]) ) # Benchmark @@ -908,11 +784,13 @@ def measure_and_update_result(): if best_placement_mem is None or best_keys_mem is None: raise AssertionError("Failed to get best placement") - self.log( - f"For group {group_id} best placement with time cost = {best_res_time.tm} ms (worst time = {worst_res_time.tm} ms):\n{best_res_time.trace}" + log( + f"For group {group_id} best placement with time cost = {best_res_time.tm} ms (worst time = {worst_res_time.tm} ms):\n{best_res_time.trace}", + level=LogLevel.DEBUG ) - self.log( - f"For group {group_id} best placement with mem cost = {best_res_mem.mem / (2**30)} GB (worst mem = {worst_res_mem.mem/(2**30)} GB) is:\n{best_res_mem.trace}" + log( + f"For group {group_id} best placement with mem cost = {best_res_mem.mem / (2**30)} GB (worst mem = {worst_res_mem.mem/(2**30)} GB) is:\n{best_res_mem.trace}", + level=LogLevel.DEBUG ) # for n, p in zip(best_keys, best_placement): @@ -969,32 +847,32 @@ def measure_and_update_result(): self.trace.bound_symbols[-1].from_bsym(args=return_not_used_vars(trc))) # NOTE: Here the active trace to place will be 'trc' and not 'self.trace' trc_time = self.place_optimizers(trc, executors_mem) - c, m, o = benchmark_trace(trc_time) + c, m, o = benchmark_trace(trc_time, self.benchmark_iters) del o - self.log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc_time}") + log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc_time}", level=LogLevel.DEBUG) self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ ex.name: trc_time}) trc_mem = self.place_optimizers(trc, executors_time) - c, m, o = benchmark_trace(trc_mem) + c, m, o = benchmark_trace(trc_mem, self.benchmark_iters) del o - self.log(f"Debug TIME, time = {c} ms:\n{trc_mem}") + log(f"Debug TIME, time = {c} ms:\n{trc_mem}", level=LogLevel.DEBUG) self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ ex.name: trc_mem}) else: trc = self.place_optimizers(self.trace, executors_mem) _, trc = rematerialize_forward_and_backward( self.active_fw_trace, trc) - c, m, o = benchmark_trace(trc) + c, m, o = benchmark_trace(trc, self.benchmark_iters) del o - self.log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc}") + log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc}", level=LogLevel.DEBUG) self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ ex.name: trc}) trc = self.place_optimizers(self.trace, executors_time) _, trc = rematerialize_forward_and_backward( self.active_fw_trace, trc) - c, m, o = benchmark_trace(trc) + c, m, o = benchmark_trace(trc, self.benchmark_iters) del o - self.log(f"Debug TIME, time = {c} ms:\n{trc}") + log(f"Debug TIME, time = {c} ms:\n{trc}", level=LogLevel.DEBUG) self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ ex.name: trc}) @@ -1013,10 +891,10 @@ def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: # This is agnostic from the optimization strat as results are both floats min_value: float = float("inf") ans: FinalOutputCandidates | None = None - self.log(f'Computing the best pair option (tot options = {len(self.out)})') + log(f'Computing the best pair option (tot options = {len(self.out)})', level=LogLevel.INFO) for pair in self.out: if pair.tot_cost < min_value: - self.log(f"New best pair:\n{pair}") + log(f"New best pair:\n{pair}", level=LogLevel.INFO) min_value = pair.tot_cost ans = pair if ans is None: @@ -1043,14 +921,14 @@ def fw_benchmark(): label = list(pair_time.keys())[0] # TODO (matteochen): remove the benchmark here as will done later on the bw pass c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) - self.log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}') + log( + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', level=LogLevel.INFO) self.debug_msg += ( f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" ) c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) - self.log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}') + log( + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', level=LogLevel.INFO) self.debug_msg += ( f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) @@ -1059,8 +937,8 @@ def fw_benchmark(): def bw_benchmark(): - time_result = self.BenchmarkResult() - memory_result = self.BenchmarkResult() + time_result = BenchmarkResult() + memory_result = BenchmarkResult() # Find best trace for runtime for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_time_benchmark_only): @@ -1071,8 +949,8 @@ def bw_benchmark(): self.debug_msg += ( f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" ) - self.log( - f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' + log( + f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', level=LogLevel.INFO ) if trace_time < time_result.tm: time_result.tm = trace_time @@ -1092,8 +970,8 @@ def bw_benchmark(): self.debug_msg += ( f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" ) - self.log( - f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}' + log( + f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', level=LogLevel.INFO ) if trace_mem < memory_result.mem: memory_result.tm = trace_time @@ -1102,19 +980,19 @@ def bw_benchmark(): memory_result.label = label memory_result.index = i - self.log( - f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.tm} ms)":\n{time_result.trace}') - self.log( - f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.mem / (2 ** 30)} GB)":\n{memory_result.trace}') + log( + f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.tm} ms)":\n{time_result.trace}', level=LogLevel.INFO) + log( + f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.mem / (2 ** 30)} GB)":\n{memory_result.trace}', level=LogLevel.INFO) # TODO (matteochen): remove this - # self.log(f"Strat comparison: {self.trace_type}") + # log(f"Strat comparison: {self.trace_type}") # c, m, o = benchmark_trace(tm.trace) # del o - # self.log(f"best time: {c} ms, {m/(2**30)} GB") + # log(f"best time: {c} ms, {m/(2**30)} GB") # c, m, o = benchmark_trace(mem.trace) # del o - # self.log(f"best mem: {c} ms, {m/(2**30)} GB") + # log(f"best mem: {c} ms, {m/(2**30)} GB") # Here we have to recover the traces without the pass through remat in order to be compliant # with thunder flow as we might have request for no remat @@ -1128,7 +1006,7 @@ def bw_benchmark(): trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0] self.bw_trace_candidates.attach_best_mem_candidate(trc) - self.log(self.bw_trace_candidates.__repr__()) + log(self.bw_trace_candidates.__repr__(), level=LogLevel.DEBUG) # Now, finally build the pair fw and bw traces for the requested strat # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller @@ -1138,8 +1016,8 @@ def bw_benchmark(): case OptimizerType.RUNTIME: # Used the computed benchmark from above if time_result.tm < memory_result.tm: - self.log( - f"out candidate times: (fw){forward_time} ms, (bw){time_result.tm} ms") + log( + f"out candidate times: (fw){forward_time} ms, (bw){time_result.tm} ms", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1148,8 +1026,8 @@ def bw_benchmark(): ) ) else: - self.log( - f"out candidate times: (fw){forward_time} ms, (bw){memory_result.tm} ms") + log( + f"out candidate times: (fw){forward_time} ms, (bw){memory_result.tm} ms", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1160,8 +1038,8 @@ def bw_benchmark(): case OptimizerType.MEMORY: # Used the computed benchmark from above if time_result.mem < memory_result.mem: - self.log( - f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){time_result.mem} GB") + log( + f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){time_result.mem / (2**30)} GB", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1170,8 +1048,8 @@ def bw_benchmark(): ) ) else: - self.log( - f"out candidate mem: (fw){forward_memory} GB, (bw){memory_result.mem} GB") + log( + f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){memory_result.mem / (2**30)} GB", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1234,7 +1112,7 @@ def is_possible_out(name: str): # TODO (matteochen): move into utils module def benchmark_trace( - trace: TraceCtx, iters: int = 1, show_func=False, apply_del_last_used=True, snapshot=False, snapshot_name="" + trace: TraceCtx, iters: int = 1, show_func=False, apply_del_last_used=True, snapshot=False, snapshot_name="", nvsight: bool = False, nvsight_fn_name: str = "" ) -> tuple[float, float, Any]: from thunder.executors.passes import del_last_used import inspect @@ -1244,9 +1122,31 @@ def benchmark_trace( if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: raise AssertionError("Missing return statement") + def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: + try: + warm_up_iters = 50 + torch.cuda.empty_cache() + # Warm up cycles + for _ in range(warm_up_iters): + fn(*args) + # Benchmark + torch.cuda.cudart().cudaProfilerStart() + for i in range(iters): + torch.cuda.nvtx.range_push(f'{nvsight_fn_name}-iter{i}') + fn(*args) + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + + return float('inf'), float('inf'), None + except Exception as e: + import inspect + trc = inspect.getsource(fn) + print(f"#NVSIGHT FN EXECUTION FAILED:\n{trc}") + raise e + def compute_time_cost_ms(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: - warm_up_iters = 3 + warm_up_iters = 50 out = None torch.cuda.empty_cache() @@ -1255,7 +1155,6 @@ def compute_time_cost_ms(fn: Callable, iters: int, *args) -> tuple[float, float, end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - max_allocated_bytes = 0 # Warm up cycles for _ in range(warm_up_iters): fn(*args) @@ -1268,6 +1167,8 @@ def compute_time_cost_ms(fn: Callable, iters: int, *args) -> tuple[float, float, torch.cuda.memory._record_memory_history(enabled=None) # Benchmark stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + torch.cuda.synchronize() for i in range(iters): torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() @@ -1352,6 +1253,8 @@ def transform_input_tuple(t: tuple, level=0) -> tuple: res.append(0 if e.value is None else e.value) elif isinstance(e, FloatProxy): res.append(0.0 if e.value is None else e.value) + elif e is None: + res.append(None) else: raise AssertionError( f"Input arg type not recognized: {type(e)}") @@ -1369,12 +1272,12 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: if dtype is not None and is_float_dtype(dtype): torch_dtype = thunder_to_torch_float_dtype(dtype, dtype.bytes) tensor: torch.Tensor = torch.randn( - *shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad ) elif dtype is not None and is_signedinteger_dtype(dtype): torch_dtype = thunder_to_torch_int_dtype(dtype.bytes) tensor: torch.Tensor = torch.randint( - 0, 10, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad ) elif dtype is not None and is_boolean_dtype(dtype): # TODO (matteochen): maybe random? @@ -1386,6 +1289,24 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: return tensor + # print(f'BENCHMARKING:\n{trace}') + # def p(args): + # for e in args: + # if not isinstance(e, Sequence): + # if isinstance(e, torch.Tensor): + # print(f'{e.size()}') + # else: + # try: + # print(f'{e.name} -> {e}') + # except: + # print(f'{e}') + # else: + # print('rec') + # p(e) + # p(trace.args) + # print('##################') + # p(input_args) + # Can we remove this check? # TODO (matteochen): use more appropriate mock int and float if isinstance(trace.args, Sequence): @@ -1412,24 +1333,6 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: if apply_del_last_used: trace = del_last_used(trace) - # print(f'BENCHMARKING:\n{trace}') - # def p(args): - # for e in args: - # if not isinstance(e, Sequence): - # if isinstance(e, torch.Tensor): - # print(f'{e.size()}') - # else: - # try: - # print(f'{e.name} -> {e}') - # except: - # print(f'{e}') - # else: - # print('rec') - # p(e) - # p(trace.args) - # print('##################') - # p(input_args) - trace_tok = set_tracectx(trace) # Obtain the python executable string @@ -1441,11 +1344,14 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: m = float("inf") answer = None try: - t, m, answer = compute_time_cost_ms(executable, iters, *input_args) + if nvsight: + t, m, answer = compute_time_cost_nvsight(executable, iters, *input_args) + else: + t, m, answer = compute_time_cost_ms(executable, iters, *input_args) except Exception as e: # https://github.com/Lightning-AI/lightning-thunder/issues/664 print(f"Exception:\n{e}") - if "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e): + if "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) and not nvsight: print( "Executing with torch compile no full graph (this might still fail), see: https://github.com/Lightning-AI/lightning-thunder/issues/664" ) @@ -1461,3 +1367,4 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: reset_tracectx(trace_tok) return t, m, answer + diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py new file mode 100644 index 0000000000..8ea9a90f6a --- /dev/null +++ b/thunder/benchmarks/utils.py @@ -0,0 +1,136 @@ +import torch +from thunder.backend_optimizer.optimizer import benchmark_trace + +warm_up_iters = 50 + +def torch_fw_bw_benchmark_nvsight(models: list, torch_module: torch.nn.Module | None, labels: list, inputs: list, iters: int, int_input_tensor: bool = False) -> None: + + for m, input, label in zip(models, inputs, labels): + # Warm up + for _ in range(10): + y = m(input) + # Not supported by autograd + if int_input_tensor: + torch.autograd.grad(y.sum(), torch_module.parameters()) + else: + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + + torch.cuda.cudart().cudaProfilerStart() + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda.nvtx.range_push(f'iteration_nvsight-{label}') + torch.cuda.nvtx.range_push("fw_nvsight") + y = m(input) + torch.cuda.nvtx.range_pop() + # Not supported by autograd + if int_input_tensor: + torch.cuda.nvtx.range_push("bw_nvsight") + torch.autograd.grad(y.sum(), torch_module.parameters()) + torch.cuda.nvtx.range_pop() + else: + torch.cuda.nvtx.range_push("bw_nvsight") + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + +def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: + + for m, input, label in zip(models, inputs, labels): + # Warm up + for _ in range(warm_up_iters): + y = m(input) + y.sum().backward() + + torch.cuda.synchronize() + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + + start_events[i].record(stream) + y = m(input) + end_events[i].record(stream) + + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f'{label} tot fw time: {tot_time} ms') + print(f'{label} max fw allocated memory: {max_allocated_bytes / (2**30)} GB') + + torch.cuda.synchronize() + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + + y = m(input) + start_events[i].record(stream) + y.sum().backward() + end_events[i].record(stream) + + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f'{label} tot bw time: {tot_time} ms') + print(f'{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB') + +def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: + + for m, input, label in zip(models, inputs, labels): + # Warm up + for _ in range(warm_up_iters): + y = m(input) + y.sum().backward() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + torch.cuda.synchronize() + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + + start_events[i].record(stream) + y = m(input) + y.sum().backward() + end_events[i].record(stream) + + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f'{label} tot time: {tot_time} ms') + print(f'{label} max allocated memory: {max_allocated_bytes / (2**30)} GB') + + +def thunder_fw_bw_benchmark(traces: list, labels: list, iters: int, nvsight: bool = False) -> None: + for trc, label in zip(traces, labels): + c, m, _ = benchmark_trace(trc, apply_del_last_used=False, snapshot=True, snapshot_name=label, iters=iters, nvsight=nvsight, nvsight_fn_name=label) + print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') + diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 99ed7492c5..738a218dc2 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -3,7 +3,6 @@ import torch -from thunder.backend_optimizer.optimizer import OptimizerType import thunder.core.utils as utils from thunder.core.prims import PrimIDs from thunder.core.proxies import TensorProxy, variableify @@ -11,6 +10,7 @@ from thunder.core.symbol import BoundSymbol from thunder.core.trace import TraceCtx, from_trace, set_tracectx, reset_tracectx from thunder.core.transform_common import replace_redundant_inputs +from thunder.extend import OperatorExecutor if TYPE_CHECKING: from thunder.core.trace import VariableInterface @@ -113,217 +113,370 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops from thunder.executors.passes import del_last_used, transform_for_execution, autotune_transform_for_execution from thunder.visualizer.visualizer_helper import Visualizer - from thunder.backend_optimizer.optimizer import TraceType, BackendOptimizer - - utils.check(compile_data is not None, lambda: "`compile_data` is required") - # NOTE: This function is rather slow, so it's intended to be used - # behind a cache. - tensor_cls = (torch.Tensor, TensorProxy) - requires_grad_mask = tuple(isinstance(arg, tensor_cls) and arg.requires_grad for arg in flat_args) - # If none of the inputs require gradients, raise an error - if not any(requires_grad_mask): - raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") - - primal_trace = computation_trc - primal_trace = sort_data_parallel_syncs(primal_trace) - - if compile_stats is not None: - compile_stats.last_traces.append(primal_trace) - - # torch.autograd.Function doesn't support non-flat outputs, the - # grads wouldn't be propagated and backward receives None for each - # non-flat non-tensor output. The output must also be a flat tuple, - # not any other container type. So we need to flatten the outputs of - # the forward trace and inputs of the backward trace. - fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) - - fw_traces = [fw_trace] - bw_traces = [bw_trace] - - from thunder.distributed import FSDPType - - # only enable rematerialize_params_in_backward when using FSDP ZeRO3 - _rematerialize_params_in_backward = ( - getattr(compile_data.fn, "use_fsdp", False) and getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3 - ) - if _rematerialize_params_in_backward: - fw_trace, bw_trace = rematerialize_all_gather(fw_trace, bw_trace) - - # Update the backward trace to only compute gradients for the - # inputs that require gradients - assert bw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN - filtered_grads = tuple( - (arg_grad if requires_grad else None) - for arg_grad, requires_grad in utils.safe_zip(bw_trace.bound_symbols[-1].args[0], requires_grad_mask) - ) - - # autograd.Function.backward expects a flat tuple of gradients - bw_trace.bound_symbols[-1] = replace(bw_trace.bound_symbols[-1], args=(filtered_grads,)) - - _fsdp_comm_bucketing: FSDPCommBucketing | None = None - if getattr(compile_data.fn, "use_fsdp", False): - _fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc) - fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace) - - do_apply_bucketing_bw_trace: bool = getattr(compile_data.fn, "use_fsdp", False) - - # Now we can run the optimization passes on the forward trace - visualizer = Visualizer(produce_hidden=False) - backend_optimizer_ctx: BackendOptimizer | None = ( - None - if autotune_type is None - else BackendOptimizer( - priority_executors=compile_data.executors_list, - apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, - produce_log=True, - visualizer=visualizer, - optimizer_type=autotune_type, + from thunder.backend_optimizer.optimizer import log, LogLevel, TraceType, BackendOptimizer, OptimizerType, benchmark_trace + + def split(): + utils.check(compile_data is not None, lambda: "`compile_data` is required") + # NOTE: This function is rather slow, so it's intended to be used + # behind a cache. + tensor_cls = (torch.Tensor, TensorProxy) + requires_grad_mask = tuple(isinstance(arg, tensor_cls) and arg.requires_grad for arg in flat_args) + # If none of the inputs require gradients, raise an error + if not any(requires_grad_mask): + raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") + + primal_trace = computation_trc + primal_trace = sort_data_parallel_syncs(primal_trace) + + # Handled by the caller if autotune is not None + if compile_stats is not None and autotune_type is None: + compile_stats.last_traces.append(primal_trace) + + # torch.autograd.Function doesn't support non-flat outputs, the + # grads wouldn't be propagated and backward receives None for each + # non-flat non-tensor output. The output must also be a flat tuple, + # not any other container type. So we need to flatten the outputs of + # the forward trace and inputs of the backward trace. + fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) + + fw_traces = [fw_trace] + bw_traces = [bw_trace] + + from thunder.distributed import FSDPType + + # only enable rematerialize_params_in_backward when using FSDP ZeRO3 + _rematerialize_params_in_backward = ( + getattr(compile_data.fn, "use_fsdp", False) and getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3 ) - ) - - visualizer.set_fw_initial_trace(fw_trace) - # Get optimzied fw trace - fw_extrace = ( - transform_for_execution(fw_trace, executors_list=compile_data.executors_list) - if autotune_type is None - else autotune_transform_for_execution( - optimizer_context=backend_optimizer_ctx, trace=fw_trace, trace_type=TraceType.FW + if _rematerialize_params_in_backward: + fw_trace, bw_trace = rematerialize_all_gather(fw_trace, bw_trace) + + # Update the backward trace to only compute gradients for the + # inputs that require gradients + assert bw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN + filtered_grads = tuple( + (arg_grad if requires_grad else None) + for arg_grad, requires_grad in utils.safe_zip(bw_trace.bound_symbols[-1].args[0], requires_grad_mask) ) - ) - # If in default mode, otherwise the best fw will be returned only at the end - if autotune_type is None: - fw_traces.append(fw_extrace) - visualizer.set_fw_optimized_trace(fw_extrace) - - # NOTE: autotuner will take care of this - # Some of the optimization passes change proxies in the trace and - # any change in the forward trace must be reflected in the backward - # trace. - original_bw_saved_tensors_for_backward = bw_trace.args[0][0] - new_fw_saved_tensors_for_backward = fw_extrace.output[1][0] - swap_map = { - variableify(x): y - for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) - if variableify(x) != variableify(y) - } - new_bsyms = replace_redundant_inputs(swap_map, bw_trace.bound_symbols) - # replace_redundant_inputs doesn't replace the output of - # UNPACK_SEQUENCE so we do it manually. Here we have certain - # assumptions about the structure of the backward trace. - assert bw_trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL - assert bw_trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" - assert bw_trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE - assert bw_trace.bound_symbols[4].args[0].name == "C0" - new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( - swap_map, - skip_inputs=False, - skip_output=False, - skip_subsymbols=False, + # autograd.Function.backward expects a flat tuple of gradients + bw_trace.bound_symbols[-1] = replace(bw_trace.bound_symbols[-1], args=(filtered_grads,)) + + _fsdp_comm_bucketing: FSDPCommBucketing | None = None + if getattr(compile_data.fn, "use_fsdp", False): + _fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc) + fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace) + + do_apply_bucketing_bw_trace: bool = getattr(compile_data.fn, "use_fsdp", False) + + # Now we can run the optimization passes on the forward trace + visualizer = Visualizer(produce_hidden=False) + backend_optimizer_ctx: BackendOptimizer | None = ( + None + if autotune_type is None + else BackendOptimizer( + priority_executors=compile_data.executors_list, + apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, + produce_log=True, + visualizer=visualizer, + optimizer_type=autotune_type, + ) ) - bw_trace.bound_symbols = new_bsyms - - if do_apply_bucketing_bw_trace: - bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) - # Now we can run the optimization passes on the backward trace - # TODO Restore request for no rematerialization - - visualizer.set_bw_initial_trace(bw_trace) - if autotune_type is not None: - fw_extrace, bw_extrace = autotune_transform_for_execution( - optimizer_context=backend_optimizer_ctx, trace=bw_trace, trace_type=TraceType.BW - ) - fw_traces.append(fw_extrace) - visualizer.set_bw_optimized_trace(fw_extrace) - else: - bw_extrace = transform_for_execution( - bw_trace, - executors_list=compile_data.executors_list, + visualizer.set_fw_initial_trace(fw_trace) + # Get optimzied fw trace + fw_extrace = ( + transform_for_execution(fw_trace, executors_list=compile_data.executors_list) + if autotune_type is None + else autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=fw_trace, trace_type=TraceType.FW + ) ) - bw_traces.append(bw_extrace) - visualizer.set_bw_optimized_trace(bw_extrace) - - # TODO Restore request for no rematerialization - fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) - fw_traces.append(fw_extrace) - bw_traces.append(bw_extrace) - - # We need to sort the waits in forward and backward trace to overlap - # computation with communication - # For performance we need the wait_prim_impl nodes in the execution trace to be as far from the - # communication ops as possible. But it causes the all_gather_prim_impl nodes gathered at the start of - # backward trace and increases the peak allocated memory - use_fsdp: bool = getattr(compile_data.fn, "use_fsdp", False) - if use_fsdp: - assert hasattr(compile_data.fn, "sharding_strategy") - if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3: - from thunder.distributed import FSDPBucketingStrategy - from thunder.distributed.utils import limit_in_flight_allgathers - - fw_extrace = sort_communication_ops(fw_extrace) - fw_extrace = limit_in_flight_allgathers( - fw_extrace, - 3, - compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, + + # If in default mode, otherwise the best fw will be returned only at the end + if autotune_type is None: + fw_traces.append(fw_extrace) + visualizer.set_fw_optimized_trace(fw_extrace) + + # NOTE: autotuner will take care of this + # Some of the optimization passes change proxies in the trace and + # any change in the forward trace must be reflected in the backward + # trace. + original_bw_saved_tensors_for_backward = bw_trace.args[0][0] + new_fw_saved_tensors_for_backward = fw_extrace.output[1][0] + swap_map = { + variableify(x): y + for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) + if variableify(x) != variableify(y) + } + new_bsyms = replace_redundant_inputs(swap_map, bw_trace.bound_symbols) + # replace_redundant_inputs doesn't replace the output of + # UNPACK_SEQUENCE so we do it manually. Here we have certain + # assumptions about the structure of the backward trace. + assert bw_trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL + assert bw_trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" + assert bw_trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE + assert bw_trace.bound_symbols[4].args[0].name == "C0" + new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( + swap_map, + skip_inputs=False, + skip_output=False, + skip_subsymbols=False, ) - bw_extrace = sort_communication_ops(bw_extrace) - bw_extrace = limit_in_flight_allgathers( - bw_extrace, - 3, - compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, + bw_trace.bound_symbols = new_bsyms + + if do_apply_bucketing_bw_trace: + bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) + + # Now we can run the optimization passes on the backward trace + # TODO Restore request for no rematerialization + + visualizer.set_bw_initial_trace(bw_trace) + if autotune_type is not None: + fw_extrace, bw_extrace = autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=bw_trace, trace_type=TraceType.BW ) - if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO2: - from thunder.distributed import FSDPBucketingStrategy - from thunder.distributed.utils import limit_in_flight_allgathers - from sys import maxsize as INT_MAX - - # sort the allgather+wait as consumer order just before consumer - fw_extrace = sort_communication_ops(fw_extrace) - # unlimited number of allgathers, i.e. allgathers are listed at the beginning of the trace in consumer order and wait stays just before wait - fw_extrace = limit_in_flight_allgathers( - fw_extrace, - INT_MAX, - compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, + fw_traces.append(fw_extrace) + visualizer.set_bw_optimized_trace(fw_extrace) + else: + bw_extrace = transform_for_execution( + bw_trace, + executors_list=compile_data.executors_list, ) + bw_traces.append(bw_extrace) + visualizer.set_bw_optimized_trace(bw_extrace) + + # TODO Restore request for no rematerialization + c, m, _ = benchmark_trace(fw_extrace, iters=50) + log(f'before remat fw trace time = {c}, mem = {m}', level=LogLevel.INFO) + c, m, _ = benchmark_trace(bw_extrace, iters=50) + log(f'before remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) + fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + c, m, _ = benchmark_trace(fw_extrace, iters=50) + log(f'after remat fw trace time = {c}, mem = {m}', level=LogLevel.INFO) + c, m, _ = benchmark_trace(bw_extrace, iters=50) + log(f'after remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) + fw_traces.append(fw_extrace) + bw_traces.append(bw_extrace) + + # We need to sort the waits in forward and backward trace to overlap + # computation with communication + # For performance we need the wait_prim_impl nodes in the execution trace to be as far from the + # communication ops as possible. But it causes the all_gather_prim_impl nodes gathered at the start of + # backward trace and increases the peak allocated memory + use_fsdp: bool = getattr(compile_data.fn, "use_fsdp", False) + if use_fsdp: + assert hasattr(compile_data.fn, "sharding_strategy") + if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3: + from thunder.distributed import FSDPBucketingStrategy + from thunder.distributed.utils import limit_in_flight_allgathers + + fw_extrace = sort_communication_ops(fw_extrace) + fw_extrace = limit_in_flight_allgathers( + fw_extrace, + 3, + compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, + ) + bw_extrace = sort_communication_ops(bw_extrace) + bw_extrace = limit_in_flight_allgathers( + bw_extrace, + 3, + compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, + ) + if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO2: + from thunder.distributed import FSDPBucketingStrategy + from thunder.distributed.utils import limit_in_flight_allgathers + from sys import maxsize as INT_MAX + + # sort the allgather+wait as consumer order just before consumer + fw_extrace = sort_communication_ops(fw_extrace) + # unlimited number of allgathers, i.e. allgathers are listed at the beginning of the trace in consumer order and wait stays just before wait + fw_extrace = limit_in_flight_allgathers( + fw_extrace, + INT_MAX, + compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, + ) + bw_extrace = sort_waits(bw_extrace) + use_ddp: bool = getattr(compile_data.fn, "use_ddp", False) + if use_ddp: bw_extrace = sort_waits(bw_extrace) - use_ddp: bool = getattr(compile_data.fn, "use_ddp", False) - if use_ddp: - bw_extrace = sort_waits(bw_extrace) - if (not use_ddp) and (not use_fsdp): - from thunder.distributed.utils import maybe_sort_waits + if (not use_ddp) and (not use_fsdp): + from thunder.distributed.utils import maybe_sort_waits - _, fw_extrace = maybe_sort_waits(fw_extrace) - _, bw_extrace = maybe_sort_waits(bw_extrace) + _, fw_extrace = maybe_sort_waits(fw_extrace) + _, bw_extrace = maybe_sort_waits(bw_extrace) - # Importing here to avoid cyclical dependencies in future. - from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex + # Importing here to avoid cyclical dependencies in future. + from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex - if transformer_engine_ex in compile_data.executors_list: - # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`. - _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace) + if transformer_engine_ex in compile_data.executors_list: + # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`. + _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace) - fw_extrace = del_last_used(fw_extrace) - fw_traces.append(fw_extrace) + fw_extrace = del_last_used(fw_extrace) + fw_traces.append(fw_extrace) + + bw_extrace = del_last_used(bw_extrace, clear_mutable_collections=True) + bw_traces.append(bw_extrace) - bw_extrace = del_last_used(bw_extrace, clear_mutable_collections=True) - bw_traces.append(bw_extrace) + bw_trace = rename_bwd_trace_outputs(bw_extrace, fw_extrace) - bw_trace = rename_bwd_trace_outputs(bw_extrace, fw_extrace) + # This is moved to the caller if autotune is enabled + if compile_stats is not None and autotune_type is None: + compile_stats.last_traces += fw_traces + compile_stats.last_backward_traces += bw_traces - if compile_stats is not None: - compile_stats.last_traces += fw_traces - compile_stats.last_backward_traces += bw_traces + # Enable wrapping with `te.fp8_autocast`. + fw_extrace._include_te_fp8_autocast = True + # We only want the forward function to be called with `te.fp8_autocast` manager. + bw_extrace._include_te_fp8_autocast = False - # Enable wrapping with `te.fp8_autocast`. - fw_extrace._include_te_fp8_autocast = True - # We only want the forward function to be called with `te.fp8_autocast` manager. - bw_extrace._include_te_fp8_autocast = False + # Let's include the last traces also after all the passes + visualizer.set_fw_final_trace(fw_extrace) + visualizer.set_bw_final_trace(bw_extrace) - # Let's include the last traces also after all the passes - visualizer.set_fw_final_trace(fw_extrace) - visualizer.set_bw_final_trace(bw_extrace) + # visualizer.produce() - # visualizer.produce() + if autotune_type is None: + return fw_extrace, bw_extrace + else: + return primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces - return fw_extrace, bw_extrace + # Defined executors that are matched inside the fw and bw split, hence outside the autotuner scope + # TODO (matteochen): integrate Transofrmer Engine + from thunder.executors.sdpaex import sdpa_ex + from thunder.executors.cudnnex import cudnn_ex + from thunder.executors.transformer_engineex import transformer_engine_ex + + executors_candidates: dict[str, list] = { + 'scaled_dot_product_attention': [sdpa_ex.name, cudnn_ex.name], + 'linear_layer': [transformer_engine_ex.name] + } + + # TODO (matteochen): use BackendOptimizer tracing + + # If autotuner is enabled, we compare different impl of executors which are assigned inside the call 'forward_and_backward_from_trace' + # as the autotuner will receive already split fw and bw traces + if autotune_type is not None: + cached_executor_list = list(compile_data.executors_list) + try: + is_tuned = False + + # We are interested to save the best_*s at the last iteration over the executors_candidates dict as the last + # out *_extrace from calling split will contain all the best executors computed incrementally + # i.e: best_* will track the best placemet for iteration (executors_candidates iteration) i plus every iteration from [0, i-1] + best_cost: float = float('inf') + best_fw_extrace: TraceCtx | None = None + best_bw_extrace: TraceCtx | None = None + best_fw_traces: list[TraceCtx] = [] + best_bw_traces: list[TraceCtx] = [] + best_primal_trace: TraceCtx | None = None + best_executor: OperatorExecutor | None = None + + for i, (ex_type, ex_list) in enumerate(executors_candidates.items()): + log( + f"================================================================================ Before Autotune Tuning: Optimizing {ex_type}", + level=LogLevel.DEBUG) + # Search in the requested executor list if one or more than one options for a know multiple executable region is available + to_benchmark = [ex for ex in cached_executor_list if ex.name in ex_list] + + if not to_benchmark: + log( + f"================================================================================ Before Autotune Tuning: Skipping optimization for {ex_type} as not requested.", + level=LogLevel.DEBUG) + + for e in to_benchmark: + compile_data.executors_list = [ex for ex in cached_executor_list if ex not in to_benchmark] + compile_data.executors_list.insert(0, e) + log( + f"================================================================================ Before Autotune Tuning: Testing compile data executors: {compile_data.executors_list}", level=LogLevel.DEBUG) + + primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() + time_fw, mem_fw, _ = benchmark_trace(fw_extrace, iters=10, apply_del_last_used=False) + time_bw, mem_bw, _ = benchmark_trace(bw_extrace, iters=10, apply_del_last_used=False) + tot_time = time_fw + time_bw + tot_mem = mem_fw + mem_bw + log( + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time fw = {time_fw} ms - Time bw = {time_bw} ms - Mem fw = {mem_fw / (2**30)} GB - Mem bw = {mem_bw / (2**30)} GB", level=LogLevel.DEBUG) + log( + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time = {tot_time} ms - Mem = {tot_mem / (2**30)} GB", level=LogLevel.DEBUG) + log(f'Fw trace:\n{fw_extrace}', level=LogLevel.DEBUG) + log(f'Bw trace:\n{bw_extrace}', level=LogLevel.DEBUG) + + benchmark_cost = tot_time if autotune_type == OptimizerType.RUNTIME else tot_mem + if benchmark_cost < best_cost: + is_tuned = True + best_cost = benchmark_cost + best_fw_extrace = fw_extrace + best_bw_extrace = bw_extrace + best_fw_traces = fw_traces + best_bw_traces = bw_traces + best_primal_trace = primal_trace + best_executor = e + + # c, m , _ = benchmark_trace(best_fw_extrace, iters=10, apply_del_last_used=False) + # print(f'inside update {c}') + # c, m , _ = benchmark_trace(best_bw_extrace, iters=10, apply_del_last_used=False) + # print(f'inside update {c}') + + c, m , _ = benchmark_trace(best_fw_extrace, iters=10, apply_del_last_used=False) + log( + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best fw_extrace (time = {c}, mem = {m}):\n{best_fw_extrace}", level=LogLevel.DEBUG) + c, m , _ = benchmark_trace(best_bw_extrace, iters=10, apply_del_last_used=False) + log( + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best bw_extrace (time = {c}, mem = {m}):\n{best_bw_extrace}", level=LogLevel.DEBUG) + + # Update the executor list with the winner executor for the current ex_type + cached_executor_list = [ex for ex in cached_executor_list if ex not in to_benchmark] + # We have a solution, we don't have it if not requested from the executor list + if best_executor is not None: + cached_executor_list.insert(0, best_executor) + best_executor = None + log( + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type}, new executor list: {cached_executor_list}", level=LogLevel.DEBUG) + + # Update the compile stats on the last iter + if i == len(executors_candidates)-1: + # Check that we have solution, we don't have it if not requested from the executor list + if is_tuned: + # Restore + compile_data.executors_list = list(cached_executor_list) + + log( + f"================================================================================ Before Autotune Tuning: autotuned split_forward_backward from {executors_candidates}", level=LogLevel.DEBUG) + if compile_stats is not None: + compile_stats.last_traces.append(best_primal_trace) + compile_stats.last_traces += best_fw_traces + compile_stats.last_backward_traces += best_bw_traces + + return best_fw_extrace, best_bw_extrace + # If no solution is found at this optmization step, we proceed normally + else: + # Restore before calling split + compile_data.executors_list = list(cached_executor_list) + + log( + f"================================================================================ Before Autotune Tuning: not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.DEBUG) + primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() + if compile_stats is not None: + compile_stats.last_traces.append(primal_trace) + compile_stats.last_traces += fw_traces + compile_stats.last_backward_traces += bw_traces + + return fw_extrace, bw_extrace + except AssertionError as e: + print(f'Exception occured: {e}') + # Restore before calling split + compile_data.executors_list = list(cached_executor_list) + + log( + f"================================================================================ Before Autotune Tuning: exception occured, not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.DEBUG) + primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() + if compile_stats is not None: + compile_stats.last_traces.append(primal_trace) + compile_stats.last_traces += fw_traces + compile_stats.last_backward_traces += bw_traces + + return fw_extrace, bw_extrace + else: + return split() From 15f914eb21ac51ac6f5d2decc640aaa211b6f9d8 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 2 Aug 2024 11:20:38 +0300 Subject: [PATCH 025/171] Fixed remat visual timing / Switched towards `memory` strat / General cleanup (#6) --- examples/dev/MLP.py | 18 +- examples/dev/litGPT.py | 2 +- examples/dev/nanogpt-block.py | 4 +- examples/dev/nanogpt.py | 300 ++++++++++++------------- thunder/backend_optimizer/optimizer.py | 90 +++----- thunder/executors/torch_autograd.py | 118 ++++++---- 6 files changed, 261 insertions(+), 271 deletions(-) diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py index 9dfa7510a6..ef0bb2a7d6 100644 --- a/examples/dev/MLP.py +++ b/examples/dev/MLP.py @@ -36,21 +36,23 @@ def forward(self, x): jmodel_def = thunder.jit(model) # This model fails under some circumstances after passed the placed traced under the rematelizer - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type='memory', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + callables = [jmodel_auto, jmodel_def] + labels = ['auto', 'def'] + inputs = [x, x] + print('Results with torch fw bw benchmark:') + torch_fw_bw_benchmark(callables, labels, inputs, 5) + print('Results with thunder benchmark:') traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces.reverse() labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, 50, nvsight = False) - - callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] - inputs = [x, x] - print('Results with torch fw bw benchmark:') - torch_fw_bw_benchmark(callables, labels, inputs, 50) + labels.reverse() + thunder_fw_bw_benchmark(traces, labels, 5, nvsight = False) # for t in traces: # print(t) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index a626201e6f..0668921d5a 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -9,7 +9,7 @@ def __init__(self, layers: int, autotune_type: str) -> None: self.layers = layers self.autotune_type = autotune_type -layers = [Test(8, 'runtime')]#, Test(8, 'runtime'), Test(16, 'runtime')] +layers = [Test(8, 'runtime'), Test(8, 'runtime'), Test(16, 'runtime')] model_name = 'Llama-2-7b-hf' diff --git a/examples/dev/nanogpt-block.py b/examples/dev/nanogpt-block.py index 59ca2fe21e..75e8c9c854 100644 --- a/examples/dev/nanogpt-block.py +++ b/examples/dev/nanogpt-block.py @@ -121,13 +121,13 @@ def forward(self, x): print('Results thunder benchmark:') traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, 50) + thunder_fw_bw_benchmark(traces, labels, 5) print('\n\nResults torch fw bw benchmark:') callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] inputs = [x, x] - torch_fw_bw_benchmark(callables, labels, inputs, 50) + torch_fw_bw_benchmark(callables, labels, inputs, 5) print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index 23318334c6..aab28ab332 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -2,161 +2,151 @@ import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark from thunder.tests.nanogpt_model import GPTConfig, GPT +from contextlib import nullcontext warm_up_iters = 50 -def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: - - for m, input, label in zip(models, inputs, labels): - # Warm up - for _ in range(warm_up_iters): - _, loss = m(input) - loss.backward() - - torch.cuda.synchronize() - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - stream = torch.cuda.current_stream() - max_allocated_bytes = 0 - for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - - start_events[i].record(stream) - out = m(input) - end_events[i].record(stream) - - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) - ) - - torch.cuda.synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print(f'{label} tot fw time: {tot_time} ms') - print(f'{label} max fw allocated memory: {max_allocated_bytes / (2**30)} GB') - - torch.cuda.synchronize() - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - stream = torch.cuda.current_stream() - max_allocated_bytes = 0 - for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - - _, loss = m(input) - loss.backward() - start_events[i].record(stream) - loss.backward() - end_events[i].record(stream) - - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) - ) - - torch.cuda.synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print(f'{label} tot bw time: {tot_time} ms') - print(f'{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB') - -def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: - - for m, input, label in zip(models, inputs, labels): - # Warm up - for _ in range(warm_up_iters): - _, loss = m(input) - loss.backward() - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - torch.cuda.synchronize() - stream = torch.cuda.current_stream() - max_allocated_bytes = 0 - for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - - start_events[i].record(stream) - _, loss = m(input) - loss.backward() - end_events[i].record(stream) - - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) - ) - - torch.cuda.synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print(f'{label} tot time: {tot_time} ms') - print(f'{label} max allocated memory: {max_allocated_bytes / (2**30)} GB') - -# ----------------------------------------------------------------------------- -batch_size = 12 -block_size = 1024 -bias = False -seed = 1337 -device = 'cuda' -# dtype = 'float16' -dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' -# ----------------------------------------------------------------------------- -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul -torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn -device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast -ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] - -x = torch.randint(50304, (batch_size, block_size), device=device) -y = torch.randint(50304, (batch_size, block_size), device=device) -get_batch = lambda split: (x, y) - -# model init -gptconf = GPTConfig( - block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 1, n_head = 12, n_embd = 768, # size of the model - dropout = 0, # for determinism - bias = bias, -) -model = GPT(gptconf) -model.to(device) - -jmodel_def = thunder.jit(model) -jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) - -X, Y = get_batch('train') -# Run compilation -jmodel_def(x, y) -jmodel_auto(x, y) - -print('Results thunder benchmark:') -traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] -labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] -thunder_fw_bw_benchmark(traces, labels, 50) - -print('\n\nResults torch fw bw benchmark:') -callables = [jmodel_def, jmodel_auto] -labels = ['def', 'auto'] -inputs = [x, x] -torch_fw_bw_benchmark(callables, labels, inputs, 50) -print('\n\nResults torch tot benchmark:') -torch_total_benchmark(callables, labels, inputs, 50) - -print('\n\n\n\n\n\n') -print(f'{thunder.last_traces(jmodel_def)[-1]}') -print('###############################################################################') -print(f'{thunder.last_traces(jmodel_auto)[-1]}') - -print('\n\n') -print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') -print('###############################################################################') -print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - +def run(): + # ----------------------------------------------------------------------------- + batch_size = 12 + block_size = 1024 + bias = False + real_data = False + seed = 1337 + device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. + dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' + compile = False # use PyTorch 2.0 to compile the model to be faster + profile = False # use pytorch profiler, or just simple benchmarking? + # exec(open('configurator.py').read()) # overrides from command line or config file + # ----------------------------------------------------------------------------- + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast + ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] + ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) + + # data loading init + if real_data: + raise RuntimeError('Not supported') + # dataset = 'openwebtext' + # data_dir = os.path.join('data', dataset) + # train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') + # def get_batch(split): + # data = train_data # note ignore split in benchmarking script + # ix = torch.randint(len(data) - block_size, (batch_size,)) + # x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) + # y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) + # x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) + # return x, y + else: + # alternatively, if fixed data is desired to not care about data loading + x = torch.randint(50304, (batch_size, block_size), device=device) + y = torch.randint(50304, (batch_size, block_size), device=device) + get_batch = lambda split: (x, y) + + # model init + gptconf = GPTConfig( + block_size = block_size, # how far back does the model look? i.e. context size + n_layer = 2, n_head = 12, n_embd = 768, # size of the model + dropout = 0, # for determinism + bias = bias, + ) + model = GPT(gptconf) + model.to(device) + + jmodel_def = thunder.jit(model) + # Currently sdpa does not work? + jmodel_auto = thunder.jit(model, autotune_type='memory', executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) + + # optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) + + if compile: + print("Compiling model...") + model = torch.compile(model) # pytorch 2.0 + + if profile: + # useful docs on pytorch profiler: + # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html + # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile + wait, warmup, active = 5, 5, 5 + num_steps = wait + warmup + active + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), + record_shapes=False, + profile_memory=False, + with_stack=False, # incurs an additional overhead, disable if not needed + with_flops=True, + with_modules=False, # only for torchscript models atm + ) as prof: + + models = [jmodel_def, jmodel_auto] + + for mod in models: + print('Profiling new model') + X, Y = get_batch('train') + for k in range(num_steps): + with ctx: + logits, loss = model(X, Y) + X, Y = get_batch('train') + # optimizer.zero_grad(set_to_none=True) + loss.backward() + # optimizer.step() + lossf = loss.item() + print(f"{k}/{num_steps} loss: {lossf:.4f}") + + prof.step() # notify the profiler at end of each step + + else: + def measure(m, label): + # simple benchmarking + torch.cuda.synchronize() + + X, Y = get_batch('train') + for i in range(warm_up_iters): + with ctx: + logits, loss = m(X, Y) + X, Y = get_batch('train') + loss.backward() + + iters = 5 + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + torch.cuda.synchronize() + X, Y = get_batch('train') + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + with ctx: + logits, loss = m(X, Y) + X, Y = get_batch('train') + loss.backward() + end_events[i].record(stream) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print('\n\nResults torch benchmark:') + print(f'{label} tot time: {tot_time} ms') + + measure(jmodel_auto, 'auto') + measure(jmodel_def, 'def') + + print('\n\nResults thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces.reverse() + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + labels.reverse() + thunder_fw_bw_benchmark(traces, labels, 5) + + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + for t in traces: + print(f'{t}\n############################################') + +run() diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 48a673ddfc..53b19883de 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -195,8 +195,6 @@ def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: swapmap: dict[Variable, Proxy] = {} - # During the fusion pass and CSE optimizatons some args in trace regions could be different from the cached args. Restore the correct arguments - # https://pytorch-lightning.slack.com/archives/C06QA9M8L3C/p1720732254341999 def restore_correct_args(trace_in: TraceCtx): def args_eq(a, b) -> bool: if len(a) != len(b): @@ -348,8 +346,8 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return extrace def optimize(self, strat: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER): - from thunder.core.transform_common import replace_redundant_inputs from thunder.core.transform_common import dce + from thunder.executors.torch_autograd import update_bw_from_forward_optimization self.optimization_algorithm = strat @@ -412,32 +410,7 @@ def match_optimizer_algorithm(): log(f"Cached fw trace:\n{self.active_fw_trace}", level=LogLevel.DEBUG) log(f"Input bw trace:\n{self.trace}", level=LogLevel.DEBUG) - # Some of the optimization passes change proxies in the trace and - # any change in the forward trace must be reflected in the backward - # trace. - original_bw_saved_tensors_for_backward = self.trace.args[0][0] - new_fw_saved_tensors_for_backward = trc.output[1][0] - swap_map = { - variableify(x): y - for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) - if variableify(x) != variableify(y) - } - new_bsyms = replace_redundant_inputs( - swap_map, self.trace.bound_symbols) - # replace_redundant_inputs doesn't replace the output of - # UNPACK_SEQUENCE so we do it manually. Here we have certain - # assumptions about the structure of the backward trace. - assert self.trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL - assert self.trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" - assert self.trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE - assert self.trace.bound_symbols[4].args[0].name == "C0" - new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( - swap_map, - skip_inputs=False, - skip_output=False, - skip_subsymbols=False, - ) - self.trace.bound_symbols = new_bsyms + self.trace = update_bw_from_forward_optimization(fw=trc, bw=self.trace) if self.apply_bucketing_bw_trace: from thunder.distributed.transforms import FSDPCommBucketing @@ -453,7 +426,7 @@ def match_optimizer_algorithm(): def can_executor_execute(self, ex: Executor, bsym: BoundSymbol) -> bool: try: return ex.can_execute(bsym) - except: + except Exception: return False # For each fusion executor in the input list, find the best trace dispatching for each executor @@ -605,8 +578,8 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): log(f"Group id: {id}", level=LogLevel.DEBUG) for sub in group: log(f"{sub.sym.name} -> out: {sub.output}", level=LogLevel.DEBUG) - # if len(group) > 0: - # print("\n") + if log_level == LogLevel.DEBUG and len(group) > 0: + print("\n") dict_time_strat: dict[str, Executor] = {} dict_mem_strat: dict[str, Executor] = {} @@ -680,11 +653,6 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # Inside groups we should have alwasy tensors as out best_res_time = BenchmarkResult() best_res_mem = BenchmarkResult() - worst_res_time = BenchmarkResult() - worst_res_mem = BenchmarkResult() - # Only for visual - worst_res_mem.measure = 0 - worst_res_time.measure = 0 # TODO (matteochen): Aggregate them best_placement_time = None @@ -696,11 +664,9 @@ def measure_and_update_result(): nonlocal best_res_time nonlocal best_placement_time nonlocal best_keys_time - nonlocal worst_res_time nonlocal best_res_mem nonlocal best_placement_mem nonlocal best_keys_mem - nonlocal worst_res_mem trc, keys, placements = get_placed_trace( dict_time_strat, increasing_symbols) if self.trace_type == TraceType.BW and self.active_fw_trace is not None: @@ -716,8 +682,6 @@ def measure_and_update_result(): best_res_time.trace = trc best_placement_time = placements best_keys_time = keys - if cost > worst_res_time.tm: - worst_res_time.tm = cost if mem < best_res_mem.mem or (mem == best_res_mem.mem and cost < best_res_mem.tm): best_res_mem.tm = cost @@ -725,8 +689,6 @@ def measure_and_update_result(): best_res_mem.trace = trc best_placement_mem = placements best_keys_mem = keys - if mem > worst_res_mem.mem: - worst_res_mem.mem = mem start_idx = 0 # This is to accomodate the following TODO @@ -736,7 +698,7 @@ def measure_and_update_result(): for idx in range(0, len(group)): if group[idx].sym.name == "embedding_backward": last_embedding_idx = idx - log(f"last embedding {last_embedding_idx}", level=LogLevel.DEBUG) + log(f"last embedding idx: {last_embedding_idx}", level=LogLevel.DEBUG) if last_embedding_idx != -1: # Until last_embedding_idx (included) assigned to current fusion ex for i in range(0, last_embedding_idx + 1, 1): @@ -785,11 +747,11 @@ def measure_and_update_result(): raise AssertionError("Failed to get best placement") log( - f"For group {group_id} best placement with time cost = {best_res_time.tm} ms (worst time = {worst_res_time.tm} ms):\n{best_res_time.trace}", + f"For group {group_id} best placement with time cost = {best_res_time.tm} ms:\n{best_res_time.trace}", level=LogLevel.DEBUG ) log( - f"For group {group_id} best placement with mem cost = {best_res_mem.mem / (2**30)} GB (worst mem = {worst_res_mem.mem/(2**30)} GB) is:\n{best_res_mem.trace}", + f"For group {group_id} best placement with mem cost = {best_res_mem.mem / (2**30)} GB:\n{best_res_mem.trace}", level=LogLevel.DEBUG ) @@ -876,7 +838,7 @@ def measure_and_update_result(): self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ ex.name: trc}) - # Save executors in order to generate real fw and bw trace with correct output + # Save executors in order to generate real fw and bw trace with correct output with the placer self.executor_placement_options.placement_options_time.append( executors_time) self.executor_placement_options.placement_options_mem.append( @@ -894,11 +856,22 @@ def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: log(f'Computing the best pair option (tot options = {len(self.out)})', level=LogLevel.INFO) for pair in self.out: if pair.tot_cost < min_value: - log(f"New best pair:\n{pair}", level=LogLevel.INFO) + log(f"New best pair:\n{pair}", level=LogLevel.DEBUG) min_value = pair.tot_cost ans = pair if ans is None: raise AssertionError('Best pair not found') + + fw = ans.fw + c, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) + log(f'Final pair fw: {c} ms - {m / (2**30)} GB\n{fw}', level=LogLevel.INFO) + bw = ans.bw + c, m, _ = benchmark_trace(bw, iters=self.benchmark_iters) + log(f'Final pair bw: {c} ms - {m / (2**30)} GB\n{bw}', level=LogLevel.INFO) + + + # To debug this: the traces that we will received in the remat call in should be the same as these and runtime should be in line with the best pair time. + # The pairs above are traces with no remat call (in order to be called later on) but their tracking time are made with traces gone under the remat call return ans.fw, ans.bw def bsym_assigned(self, bsym: BoundSymbol) -> bool: @@ -1012,12 +985,13 @@ def bw_benchmark(): # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller forward_time, forward_memory, _ = benchmark_trace( self.active_fw_trace, self.benchmark_iters) + match self.optimizer_type: case OptimizerType.RUNTIME: # Used the computed benchmark from above if time_result.tm < memory_result.tm: log( - f"out candidate times: (fw){forward_time} ms, (bw){time_result.tm} ms", level=LogLevel.INFO) + f"out candidate times from time res: (fw){forward_time} ms, (bw){time_result.tm} ms", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1027,7 +1001,7 @@ def bw_benchmark(): ) else: log( - f"out candidate times: (fw){forward_time} ms, (bw){memory_result.tm} ms", level=LogLevel.INFO) + f"out candidate times from mem res: (fw){forward_time} ms, (bw){memory_result.tm} ms", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1039,7 +1013,7 @@ def bw_benchmark(): # Used the computed benchmark from above if time_result.mem < memory_result.mem: log( - f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){time_result.mem / (2**30)} GB", level=LogLevel.INFO) + f"out candidate mem from time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){time_result.mem / (2**30)} GB (bw){time_result.tm} ms", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1049,7 +1023,7 @@ def bw_benchmark(): ) else: log( - f"out candidate mem: (fw){forward_memory / (2**30)} GB, (bw){memory_result.mem / (2**30)} GB", level=LogLevel.INFO) + f"out candidate mem from time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){memory_result.mem / (2**30)} GB (bw){memory_result.tm} ms", level=LogLevel.INFO) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1144,7 +1118,7 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f print(f"#NVSIGHT FN EXECUTION FAILED:\n{trc}") raise e - def compute_time_cost_ms(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: + def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: try: warm_up_iters = 50 out = None @@ -1184,13 +1158,11 @@ def compute_time_cost_ms(fn: Callable, iters: int, *args) -> tuple[float, float, torch.cuda.synchronize() times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + print(f'times: {times}') tot_time = sum(times) / iters return tot_time, max_allocated_bytes, out except Exception as e: - import inspect - - trc = inspect.getsource(fn) - print(f"#FN EXECUTION FAILED:\n{trc}") + print(f"#FN EXECUTION FAILED:\n{repr}") raise e def print_input_args(args, level=0, show_content=False): @@ -1336,6 +1308,7 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: trace_tok = set_tracectx(trace) # Obtain the python executable string + execuhtable_str = trace.python() executable = trace.python_callable() if show_func: print(inspect.getsource(executable)) @@ -1347,9 +1320,10 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: if nvsight: t, m, answer = compute_time_cost_nvsight(executable, iters, *input_args) else: - t, m, answer = compute_time_cost_ms(executable, iters, *input_args) + t, m, answer = compute_time_cost_ms(executable, execuhtable_str, iters, *input_args) except Exception as e: # https://github.com/Lightning-AI/lightning-thunder/issues/664 + # Seems that this patch never work ... print(f"Exception:\n{e}") if "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) and not nvsight: print( diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 738a218dc2..04c1095ade 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -106,6 +106,36 @@ def backward(ctx, *args): del grads return (None, None, None, None, None, *([None] * n_grads)) +def update_bw_from_forward_optimization(*, fw: TraceCtx, bw:TraceCtx) -> TraceCtx: + # Some of the optimization passes change proxies in the trace and + # any change in the forward trace must be reflected in the backward + # trace. + original_bw_saved_tensors_for_backward = bw.args[0][0] + new_fw_saved_tensors_for_backward = fw.output[1][0] + swap_map = { + variableify(x): y + for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) + if variableify(x) != variableify(y) + } + new_bsyms = replace_redundant_inputs( + swap_map, bw.bound_symbols) + # replace_redundant_inputs doesn't replace the output of + # UNPACK_SEQUENCE so we do it manually. Here we have certain + # assumptions about the structure of the backward trace. + assert bw.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL + assert bw.bound_symbols[0].kwargs["name"] == "saved_for_backward" + assert bw.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE + assert bw.bound_symbols[4].args[0].name == "C0" + new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( + swap_map, + skip_inputs=False, + skip_output=False, + skip_subsymbols=False, + ) + bw.bound_symbols = new_bsyms + + return bw + def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, autotune_type, /, *flat_args): from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace @@ -195,42 +225,18 @@ def split(): # If in default mode, otherwise the best fw will be returned only at the end if autotune_type is None: + # Here fw_extrace is not None + fw_traces.append(fw_extrace) visualizer.set_fw_optimized_trace(fw_extrace) - # NOTE: autotuner will take care of this - # Some of the optimization passes change proxies in the trace and - # any change in the forward trace must be reflected in the backward - # trace. - original_bw_saved_tensors_for_backward = bw_trace.args[0][0] - new_fw_saved_tensors_for_backward = fw_extrace.output[1][0] - swap_map = { - variableify(x): y - for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) - if variableify(x) != variableify(y) - } - new_bsyms = replace_redundant_inputs(swap_map, bw_trace.bound_symbols) - # replace_redundant_inputs doesn't replace the output of - # UNPACK_SEQUENCE so we do it manually. Here we have certain - # assumptions about the structure of the backward trace. - assert bw_trace.bound_symbols[0].sym.id == PrimIDs.UNPACK_TRIVIAL - assert bw_trace.bound_symbols[0].kwargs["name"] == "saved_for_backward" - assert bw_trace.bound_symbols[4].sym.id == PrimIDs.UNPACK_SEQUENCE - assert bw_trace.bound_symbols[4].args[0].name == "C0" - new_bsyms[4] = new_bsyms[4].from_bsym_swap_proxies( - swap_map, - skip_inputs=False, - skip_output=False, - skip_subsymbols=False, - ) - bw_trace.bound_symbols = new_bsyms - + # If autotuning is activated, it will take care of the followinf 2 calls + bw_trace = update_bw_from_forward_optimization(fw=fw_extrace, bw=bw_trace) if do_apply_bucketing_bw_trace: bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) # Now we can run the optimization passes on the backward trace # TODO Restore request for no rematerialization - visualizer.set_bw_initial_trace(bw_trace) if autotune_type is not None: fw_extrace, bw_extrace = autotune_transform_for_execution( @@ -247,14 +253,15 @@ def split(): visualizer.set_bw_optimized_trace(bw_extrace) # TODO Restore request for no rematerialization - c, m, _ = benchmark_trace(fw_extrace, iters=50) - log(f'before remat fw trace time = {c}, mem = {m}', level=LogLevel.INFO) - c, m, _ = benchmark_trace(bw_extrace, iters=50) + # TODO (matteochen): remove these logs + c, m, _ = benchmark_trace(fw_extrace, iters=5) + log(f'before remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) + c, m, _ = benchmark_trace(bw_extrace, iters=5) log(f'before remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) - c, m, _ = benchmark_trace(fw_extrace, iters=50) - log(f'after remat fw trace time = {c}, mem = {m}', level=LogLevel.INFO) - c, m, _ = benchmark_trace(bw_extrace, iters=50) + c, m, _ = benchmark_trace(fw_extrace, iters=5) + log(f'after remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) + c, m, _ = benchmark_trace(bw_extrace, iters=5) log(f'after remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) @@ -360,6 +367,8 @@ def split(): if autotune_type is not None: cached_executor_list = list(compile_data.executors_list) try: + # disable this part for now + # raise RuntimeError('Disabled') is_tuned = False # We are interested to save the best_*s at the last iteration over the executors_candidates dict as the last @@ -376,20 +385,21 @@ def split(): for i, (ex_type, ex_list) in enumerate(executors_candidates.items()): log( f"================================================================================ Before Autotune Tuning: Optimizing {ex_type}", - level=LogLevel.DEBUG) + level=LogLevel.INFO) # Search in the requested executor list if one or more than one options for a know multiple executable region is available to_benchmark = [ex for ex in cached_executor_list if ex.name in ex_list] if not to_benchmark: log( f"================================================================================ Before Autotune Tuning: Skipping optimization for {ex_type} as not requested.", - level=LogLevel.DEBUG) + level=LogLevel.INFO) for e in to_benchmark: compile_data.executors_list = [ex for ex in cached_executor_list if ex not in to_benchmark] + # Make it with most priority compile_data.executors_list.insert(0, e) log( - f"================================================================================ Before Autotune Tuning: Testing compile data executors: {compile_data.executors_list}", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: Testing compile data executors: {compile_data.executors_list}", level=LogLevel.INFO) primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() time_fw, mem_fw, _ = benchmark_trace(fw_extrace, iters=10, apply_del_last_used=False) @@ -397,11 +407,11 @@ def split(): tot_time = time_fw + time_bw tot_mem = mem_fw + mem_bw log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time fw = {time_fw} ms - Time bw = {time_bw} ms - Mem fw = {mem_fw / (2**30)} GB - Mem bw = {mem_bw / (2**30)} GB", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time fw = {time_fw} ms - Time bw = {time_bw} ms - Mem fw = {mem_fw / (2**30)} GB - Mem bw = {mem_bw / (2**30)} GB", level=LogLevel.INFO) log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time = {tot_time} ms - Mem = {tot_mem / (2**30)} GB", level=LogLevel.DEBUG) - log(f'Fw trace:\n{fw_extrace}', level=LogLevel.DEBUG) - log(f'Bw trace:\n{bw_extrace}', level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time = {tot_time} ms - Mem = {tot_mem / (2**30)} GB", level=LogLevel.INFO) + log(f'Fw trace:\n{fw_extrace}', level=LogLevel.INFO) + log(f'Bw trace:\n{bw_extrace}', level=LogLevel.INFO) benchmark_cost = tot_time if autotune_type == OptimizerType.RUNTIME else tot_mem if benchmark_cost < best_cost: @@ -421,10 +431,10 @@ def split(): c, m , _ = benchmark_trace(best_fw_extrace, iters=10, apply_del_last_used=False) log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best fw_extrace (time = {c}, mem = {m}):\n{best_fw_extrace}", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best fw_extrace (time = {c}, mem = {m}):\n{best_fw_extrace}", level=LogLevel.INFO) c, m , _ = benchmark_trace(best_bw_extrace, iters=10, apply_del_last_used=False) log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best bw_extrace (time = {c}, mem = {m}):\n{best_bw_extrace}", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best bw_extrace (time = {c}, mem = {m}):\n{best_bw_extrace}", level=LogLevel.INFO) # Update the executor list with the winner executor for the current ex_type cached_executor_list = [ex for ex in cached_executor_list if ex not in to_benchmark] @@ -433,7 +443,7 @@ def split(): cached_executor_list.insert(0, best_executor) best_executor = None log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type}, new executor list: {cached_executor_list}", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type}, new executor list: {cached_executor_list}", level=LogLevel.INFO) # Update the compile stats on the last iter if i == len(executors_candidates)-1: @@ -443,7 +453,7 @@ def split(): compile_data.executors_list = list(cached_executor_list) log( - f"================================================================================ Before Autotune Tuning: autotuned split_forward_backward from {executors_candidates}", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) if compile_stats is not None: compile_stats.last_traces.append(best_primal_trace) compile_stats.last_traces += best_fw_traces @@ -456,7 +466,7 @@ def split(): compile_data.executors_list = list(cached_executor_list) log( - f"================================================================================ Before Autotune Tuning: not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() if compile_stats is not None: compile_stats.last_traces.append(primal_trace) @@ -470,7 +480,21 @@ def split(): compile_data.executors_list = list(cached_executor_list) log( - f"================================================================================ Before Autotune Tuning: exception occured, not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.DEBUG) + f"================================================================================ Before Autotune Tuning: exception occured, not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) + primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() + if compile_stats is not None: + compile_stats.last_traces.append(primal_trace) + compile_stats.last_traces += fw_traces + compile_stats.last_backward_traces += bw_traces + + return fw_extrace, bw_extrace + except RuntimeError as e: + print(f'Exception occured: {e}') + # Restore before calling split + compile_data.executors_list = list(cached_executor_list) + + log( + f"================================================================================ Before Autotune Tuning: exception occured, not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() if compile_stats is not None: compile_stats.last_traces.append(primal_trace) From 89bba8d9a0aeb0e937e617dbf061191f22290ff3 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 2 Aug 2024 15:24:41 +0300 Subject: [PATCH 026/171] Testing different compile options for `nvfuser` (#7) --- thunder/backend_optimizer/optimizer.py | 424 ++++++++++++++----------- thunder/executors/torch_autograd.py | 1 + 2 files changed, 238 insertions(+), 187 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 53b19883de..12d51bee98 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -83,15 +83,14 @@ def __repr__(self) -> str: # Benchmark only traces will contain traces after the rematerialization call with fw and bw calls, reproducing what will be the real traces after the autotune pass # Non benchmark traces will contain traces after the placement (default) with no call to remat # We have duplciated those in order to maintain thunder compilation flow as the output from the autotuner will be the traces with no pass through rematerialization +# TODO: torchcompile_cat currently is not supported as the autotuner search space in the FusionExecutor section is limited to 1 class FusionStratHelper: def __init__(self) -> None: self.supported_executors: set = set(["nvfuser", "torchcompile"]) self.optimized_traces_mem: list[dict[str | Hashable, TraceCtx]] = [] - self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [ - ] + self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] self.optimized_traces_time: list[dict[str | Hashable, TraceCtx]] = [] - self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [ - ] + self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] class ExecutorPlacementOptions: @@ -110,11 +109,17 @@ class LogLevel(Enum): def log(what: str, level: LogLevel): if log_level == LogLevel.DEBUG or log_level == level: - print( - f"================================================================================ Autotune: {what}") + print(f"================================================================================ Autotune: {what}") + + +class FusionCompileOptionsHelper: + def __init__(self, fusion_tag: str, symbol_tag: str) -> None: + self.fusion_tag = fusion_tag + self.symbol_tag = symbol_tag class BackendOptimizer: + # from thunder.common import CompileData def __init__( self, @@ -125,6 +130,7 @@ def __init__( log_file_name="autotune_debug.log", visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME, + compile_data, ) -> None: self.always_executors: tuple[Executor, ...] = get_always_executors() self.empty_executor_hashable_placeholder: str = "empty" @@ -132,6 +138,8 @@ def __init__( self.fusion_executors: Sequence[FusionExecutor] = [ ex for ex in self.executors if isinstance(ex, FusionExecutor) ] + # Helper needed for later + self.fusion_executors_saved_for_later: Sequence[FusionExecutor] = [] self.debug_msg: str = "" self.partial_costs: dict[TraceCtx, float] = {} @@ -153,7 +161,17 @@ def __init__( self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace - self.benchmark_iters = 5 + self.benchmark_iters: int = 5 + + self.compile_data = compile_data + + self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { + "nvfuser": [ + FusionCompileOptionsHelper("nv_enable_linear", "linear"), + FusionCompileOptionsHelper("nv_enable_matmul", "matmul"), + FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), + ] + } log("Executors:", level=LogLevel.INFO) for e in self.executors: @@ -170,25 +188,26 @@ def __init__(self, symbol: BoundSymbolInterface, idx: int) -> None: def attach_cached_fw_traces(self, cached_fw_traces: TraceCandidates, executor_name: str) -> None: self.cached_fw_traces[executor_name] = cached_fw_traces - def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): from thunder.core.transform_common import dce self.trace_type = trace_type # dce for the backward trace will be passed afterwards - self.trace: TraceCtx = dce( - trace) if trace_type == TraceType.FW else trace + self.trace: TraceCtx = dce(trace) if trace_type == TraceType.FW else trace match self.trace_type: case TraceType.FW: log( - f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO) + f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO + ) # TODO (matteochen): support bw trace optimization even though with no fw traces cached case TraceType.BW: if not self.cached_fw_traces: - raise AssertionError( - "Can not optimize backward traces before forward traces") + raise AssertionError("Can not optimize backward traces before forward traces") log( - f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO) + f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", + level=LogLevel.INFO, + ) def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: from thunder.executors.passes import _transform_for_operator_executor_execution @@ -205,8 +224,7 @@ def args_eq(a, b) -> bool: return False elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): if obj_a != obj_b: - raise AssertionError( - f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") + raise AssertionError(f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") return True def clear(bsym: BoundSymbol, input): @@ -251,8 +269,7 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: if ex.name == self.empty_executor_hashable_placeholder: return None - execution_transform: None | Callable = ex.get_execution_transform( - bsym.sym) + execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) out: Any if execution_transform is not None: out = execution_transform(*bsym.args, **bsym.kwargs) @@ -276,8 +293,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError( - "len(executor_list) != len(in_trace.bound_symbols)") + raise AssertionError("len(executor_list) != len(in_trace.bound_symbols)") # log(f'Visit transf') # for n, e in zip(in_trace.bound_symbols, executor_list): @@ -288,8 +304,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: # Input should have equal length if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError( - "len(executor_list) != len(extrace.bound_symbols)") + raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") for b, e in zip(in_trace.bound_symbols, executor_list): if isinstance(e, FusionExecutor): @@ -297,8 +312,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: if isinstance(b.output, TensorProxy): executor_mapping[b.output.name] = e - extrace = transforms.visitor_transform_paired( - in_trace, visit, zip(in_trace.bound_symbols, executor_list)) + extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) # Restores original variables bound_symbols: list[BoundSymbol] = [] @@ -340,11 +354,21 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: restore_correct_args(extrace) # Apply always executors - extrace = _transform_for_operator_executor_execution( - extrace, self.always_executors) + extrace = _transform_for_operator_executor_execution(extrace, self.always_executors) return extrace + def op_in_trace(self, trace: TraceCtx, op: str) -> bool: + # Some optimizations are not available as symbols + always_true = set(["bookend"]) + + if op in always_true: + return True + for b in trace.bound_symbols: + if b.sym.name == op: + return True + return False + def optimize(self, strat: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER): from thunder.core.transform_common import dce from thunder.executors.torch_autograd import update_bw_from_forward_optimization @@ -359,20 +383,26 @@ def optmize_best_fuser(): self.build_placement_options_best_fuser() - if len(self.executor_placement_options.placement_options_time) != len(self.fusion_executors): + if len(self.executor_placement_options.placement_options_time) != len( + self.fusion_executors_saved_for_later + ): raise AssertionError( f"Unexpected time placement options size: {len(self.executor_placement_options.placement_options_time)}. Expected: {len(self.fusion_executors)}" ) - if len(self.executor_placement_options.placement_options_mem) != len(self.fusion_executors): + if len(self.executor_placement_options.placement_options_mem) != len(self.fusion_executors_saved_for_later): raise AssertionError( f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors)}" ) - for placement, ex in zip(self.executor_placement_options.placement_options_time, self.fusion_executors): + for placement, ex in zip( + self.executor_placement_options.placement_options_time, self.fusion_executors_saved_for_later + ): self.fusion_strat_helper.optimized_traces_time.append( {ex.name: self.place_optimizers(self.trace, placement)} ) - for placement, ex in zip(self.executor_placement_options.placement_options_mem, self.fusion_executors): + for placement, ex in zip( + self.executor_placement_options.placement_options_mem, self.fusion_executors_saved_for_later + ): self.fusion_strat_helper.optimized_traces_mem.append( {ex.name: self.place_optimizers(self.trace, placement)} ) @@ -391,19 +421,16 @@ def match_optimizer_algorithm(): case TraceType.BW: # Cached the bw trace as we need to modify the input trace during the loop cached_self_trace = from_trace(self.trace) - cached_self_trace.bound_symbols = list( - self.trace.bound_symbols) + cached_self_trace.bound_symbols = list(self.trace.bound_symbols) for label, candidate in self.cached_fw_traces.items(): - log(f'Backward optimization with fw from {label}', level=LogLevel.INFO) + log(f"Backward optimization with fw from {label}", level=LogLevel.INFO) fw_traces = candidate.iterable() for trc in fw_traces: - # TODO (matteochen): unify below with the original block # Restore the original bw trace self.trace = from_trace(cached_self_trace) - self.trace.bound_symbols = list( - cached_self_trace.bound_symbols) + self.trace.bound_symbols = list(cached_self_trace.bound_symbols) # Set the current active cached forward trace self.active_fw_trace = trc @@ -415,8 +442,7 @@ def match_optimizer_algorithm(): if self.apply_bucketing_bw_trace: from thunder.distributed.transforms import FSDPCommBucketing - self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace( - self.trace) + self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace(self.trace) # Not called in the constructor for bw traces dce(self.trace) @@ -445,13 +471,12 @@ def sequence_hash(s: Sequence) -> str: ): name += e.name # TODO (matteochen): investigate if this is suitable - elif isinstance(e ,int): - name += f'int{e}' + elif isinstance(e, int): + name += f"int{e}" elif e is None: name += "None" else: - raise AssertionError( - f"What? Maybe nested Sequence. type = {type(e)}") + raise AssertionError(f"What? Maybe nested Sequence. type = {type(e)}") return name # TODO (matteochen): Benchmark the optimal executor and call this optimal @@ -472,12 +497,10 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo # For this partial trace we have to return all not used tensors otherwise the dce will cut them out tensors = return_not_used_vars(trc) - forced_return_bsym = self.trace.bound_symbols[-1].from_bsym( - args=tensors) + forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=tensors) executor_configuration = [] - empty_executor = Executor( - name=self.empty_executor_hashable_placeholder) + empty_executor = Executor(name=self.empty_executor_hashable_placeholder) keys = [] for bsym in trc.bound_symbols: if bsym.sym.name == "return": @@ -486,8 +509,7 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo # keys.append('return') elif isinstance(bsym.output, Sequence): seq_hash = sequence_hash(bsym.output) - executor_configuration.append( - mapping.get(seq_hash, empty_executor)) + executor_configuration.append(mapping.get(seq_hash, empty_executor)) keys.append(seq_hash) elif ( isinstance(bsym.output, CollectionProxy) @@ -496,18 +518,15 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo or isinstance(bsym.output, FloatProxy) ): if bsym.output.name not in mapping: - raise AssertionError( - f"Expected key {bsym.output.name} in mapping {mapping}") + raise AssertionError(f"Expected key {bsym.output.name} in mapping {mapping}") executor_configuration.append(mapping[bsym.output.name]) keys.append(bsym.output.name) else: - raise AssertionError( - f"Type not handled: {type(bsym.output)}") + raise AssertionError(f"Type not handled: {type(bsym.output)}") if trc.bound_symbols[-1].sym.name != "return": trc.bound_symbols.append(forced_return_bsym) - executor_configuration.append( - Executor(name=self.empty_executor_hashable_placeholder)) + executor_configuration.append(Executor(name=self.empty_executor_hashable_placeholder)) keys.append("return") if len(trc.bound_symbols) != len(executor_configuration) or len(keys) != len(executor_configuration): @@ -522,16 +541,9 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo placed_trace = self.place_optimizers(trc, executor_configuration) return placed_trace, keys, executor_configuration - ex: FusionExecutor - for ex in self.fusion_executors: - if ex.name not in self.fusion_strat_helper.supported_executors: - raise AssertionError( - f"Fusion operator not supported: {ex.name}") - - log( - f"Searching best placement for fusion executor = {ex.name}", level=LogLevel.DEBUG) + def search(ex: FusionExecutor): + # Each executor has a custom should fuse function, but the current impl need to access local executor object - # TODO (matteochen): each executor has a custom should fuse function, can we make this prettier? def _should_fuse_nvfuser(a: Node, b: Node): def _can_fuse_node(n: Node): # if already merged, then node can be fused @@ -566,8 +578,7 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): for d in dicts: d[bsym_in.output.name] = ex_in else: - raise AssertionError( - f"Type not handled: {type(bsym_in.output)}") + raise AssertionError(f"Type not handled: {type(bsym_in.output)}") bound_symbol_groups = fuse_bound_symbols( self.trace, _should_fuse_nvfuser if ex.name == "nvfuser" else _should_fuse_torchcompile @@ -598,7 +609,11 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): log(f"--> Single group: {current_bsym.sym.name}", level=LogLevel.DEBUG) name = current_bsym.sym.name # Filter out all possible candidates for the current symbol - candidate_executors = [ex for ex in self.executors if self.can_executor_execute(ex, current_bsym) and not isinstance(ex, FusionExecutor)] + candidate_executors = [ + ex + for ex in self.executors + if self.can_executor_execute(ex, current_bsym) and not isinstance(ex, FusionExecutor) + ] if name == "return": dict_time_strat["return"] = Executor(name=self.empty_executor_hashable_placeholder) @@ -609,10 +624,13 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # Not executors available if not candidate_executors: match_bsym_output( - current_bsym, [dict_time_strat, dict_mem_strat], Executor(name=self.empty_executor_hashable_placeholder)) + current_bsym, + [dict_time_strat, dict_mem_strat], + Executor(name=self.empty_executor_hashable_placeholder), + ) continue else: - log(f'Available executors for single region:\n{candidate_executors}', level=LogLevel.DEBUG) + log(f"Available executors for single region:\n{candidate_executors}", level=LogLevel.DEBUG) # Helpers candidate_best_time = BenchmarkResult() @@ -620,14 +638,11 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # Search for best candidate for i, candidate in enumerate(candidate_executors): # Match the current candidate to benchmark partial trace - match_bsym_output( - current_bsym, [dict_time_strat, dict_mem_strat], candidate) + match_bsym_output(current_bsym, [dict_time_strat, dict_mem_strat], candidate) # Retrieve partial trace and benchmark, apply remat if possible - trc, _, _ = get_placed_trace( - dict_time_strat, increasing_symbols) + trc, _, _ = get_placed_trace(dict_time_strat, increasing_symbols) if self.trace_type == TraceType.BW and self.active_fw_trace is not None: - _, trc = rematerialize_forward_and_backward( - self.active_fw_trace, trc) + _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) t, m, _ = benchmark_trace(trc, self.benchmark_iters) # Update results if t < candidate_best_time.tm: @@ -639,15 +654,21 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): candidate_best_mem.index = i if candidate_best_time.index == -1 or candidate_best_mem.index == -1: - raise AssertionError(f'Failed to get optimal single trace region candidate. Available candidates for {name}:\n{candidate_executors}') + raise AssertionError( + f"Failed to get optimal single trace region candidate. Available candidates for {name}:\n{candidate_executors}" + ) - log(f'Best time OperatorExecutor for single {name}: {candidate_executors[candidate_best_time.index].name}', level=LogLevel.DEBUG) - log(f'Best mem OperatorExecutor for single {name}: {candidate_executors[candidate_best_mem.index].name}', level=LogLevel.DEBUG) + log( + f"Best time OperatorExecutor for single {name}: {candidate_executors[candidate_best_time.index].name}", + level=LogLevel.DEBUG, + ) + log( + f"Best mem OperatorExecutor for single {name}: {candidate_executors[candidate_best_mem.index].name}", + level=LogLevel.DEBUG, + ) - match_bsym_output( - current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) - match_bsym_output( - current_bsym, [dict_mem_strat], candidate_executors[candidate_best_mem.index]) + match_bsym_output(current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) + match_bsym_output(current_bsym, [dict_mem_strat], candidate_executors[candidate_best_mem.index]) continue # Inside groups we should have alwasy tensors as out @@ -667,15 +688,12 @@ def measure_and_update_result(): nonlocal best_res_mem nonlocal best_placement_mem nonlocal best_keys_mem - trc, keys, placements = get_placed_trace( - dict_time_strat, increasing_symbols) + trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) if self.trace_type == TraceType.BW and self.active_fw_trace is not None: - _, trc = rematerialize_forward_and_backward( - self.active_fw_trace, trc) + _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) cost, mem, out = benchmark_trace(trc, self.benchmark_iters) del out - log( - f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) + log(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) if cost < best_res_time.tm or (cost == best_res_time.tm and mem < best_res_time.mem): best_res_time.tm = cost best_res_time.mem = mem @@ -702,8 +720,7 @@ def measure_and_update_result(): if last_embedding_idx != -1: # Until last_embedding_idx (included) assigned to current fusion ex for i in range(0, last_embedding_idx + 1, 1): - match_bsym_output( - group[i], [dict_time_strat, dict_mem_strat], ex) + match_bsym_output(group[i], [dict_time_strat, dict_mem_strat], ex) if last_embedding_idx == len(group) - 1: # Benchmark @@ -712,19 +729,17 @@ def measure_and_update_result(): start_idx = last_embedding_idx + 1 n_missing_bsyms = len(group) - start_idx - for i in range(0, n_missing_bsyms, n_missing_bsyms-1 if self.trace_type == TraceType.BW else 1): - # for i in range(0, n_missing_bsyms): + # TODO (matteochen): consider to add the iteration with no fusion regions + for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): + # for i in range(0, n_missing_bsyms): # From top to bottom (this will include the whole region) # -> First iteration is the one with fusion region with single element # -> Last iteration gives the complete fusion region - for j in range(start_idx, start_idx + i + 1, increment_factor): - match_bsym_output( - group[j], [dict_time_strat, dict_mem_strat], ex) + match_bsym_output(group[j], [dict_time_strat, dict_mem_strat], ex) for k in range(start_idx + i + 1, len(group), increment_factor): match_bsym_output( - group[k], [dict_time_strat, dict_mem_strat], get_first_available_executor( - group[k]) + group[k], [dict_time_strat, dict_mem_strat], get_first_available_executor(group[k]) ) # Benchmark measure_and_update_result() @@ -748,11 +763,11 @@ def measure_and_update_result(): log( f"For group {group_id} best placement with time cost = {best_res_time.tm} ms:\n{best_res_time.trace}", - level=LogLevel.DEBUG + level=LogLevel.DEBUG, ) log( f"For group {group_id} best placement with mem cost = {best_res_mem.mem / (2**30)} GB:\n{best_res_mem.trace}", - level=LogLevel.DEBUG + level=LogLevel.DEBUG, ) # for n, p in zip(best_keys, best_placement): @@ -771,8 +786,7 @@ def measure_and_update_result(): for bsym in self.trace.bound_symbols: if bsym.sym.name == "return": if "return" not in dict_time_strat or "return" not in dict_mem_strat: - raise AssertionError( - f"Expected key return in mapping {dict_time_strat} and {dict_mem_strat}") + raise AssertionError(f"Expected key return in mapping {dict_time_strat} and {dict_mem_strat}") executors_time.append(dict_time_strat["return"]) executors_mem.append(dict_mem_strat["return"]) elif isinstance(bsym.output, Sequence): @@ -796,8 +810,7 @@ def measure_and_update_result(): executors_time.append(dict_time_strat[bsym.output.name]) executors_mem.append(dict_mem_strat[bsym.output.name]) else: - raise AssertionError( - f"Type not handled: {type(bsym.output)}") + raise AssertionError(f"Type not handled: {type(bsym.output)}") # For the forward trace we benchmark (memory) the mocked return statement as we don't know which # Tensor will be returned after the rematerialize_forward_and_backward() call in order to do not overestimate the memory consumption @@ -805,70 +818,106 @@ def measure_and_update_result(): trc = from_trace(self.trace) trc.bound_symbols = list(self.trace.bound_symbols) trc.bound_symbols.pop() - trc.bound_symbols.append( - self.trace.bound_symbols[-1].from_bsym(args=return_not_used_vars(trc))) + trc.bound_symbols.append(self.trace.bound_symbols[-1].from_bsym(args=return_not_used_vars(trc))) # NOTE: Here the active trace to place will be 'trc' and not 'self.trace' trc_time = self.place_optimizers(trc, executors_mem) c, m, o = benchmark_trace(trc_time, self.benchmark_iters) del o log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc_time}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ - ex.name: trc_time}) + self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ex.name: trc_time}) trc_mem = self.place_optimizers(trc, executors_time) c, m, o = benchmark_trace(trc_mem, self.benchmark_iters) del o log(f"Debug TIME, time = {c} ms:\n{trc_mem}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ - ex.name: trc_mem}) + self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ex.name: trc_mem}) else: trc = self.place_optimizers(self.trace, executors_mem) - _, trc = rematerialize_forward_and_backward( - self.active_fw_trace, trc) + _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) c, m, o = benchmark_trace(trc, self.benchmark_iters) del o log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ - ex.name: trc}) + self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ex.name: trc}) trc = self.place_optimizers(self.trace, executors_time) - _, trc = rematerialize_forward_and_backward( - self.active_fw_trace, trc) + _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) c, m, o = benchmark_trace(trc, self.benchmark_iters) del o log(f"Debug TIME, time = {c} ms:\n{trc}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ - ex.name: trc}) + self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ex.name: trc}) # Save executors in order to generate real fw and bw trace with correct output with the placer - self.executor_placement_options.placement_options_time.append( - executors_time) - self.executor_placement_options.placement_options_mem.append( - executors_mem) + self.executor_placement_options.placement_options_time.append(executors_time) + self.executor_placement_options.placement_options_mem.append(executors_mem) + + # If executor specific compile option is activated we need to know where a specific + # trace does come from and the zip logic afterward can not be employed with self.fusion_executors list + self.fusion_executors_saved_for_later = [] + ex: FusionExecutor + for ex in self.fusion_executors: + if ex.name not in self.fusion_strat_helper.supported_executors: + raise AssertionError(f"Fusion operator not supported: {ex.name}") + + log(f"Searching best placement for fusion executor = {ex.name}", level=LogLevel.INFO) + + # We try to enable fusion specific compile options + ex_compile_opts = self.known_fusion_ex_compile_options.get(ex.name, []) + self.fusion_executors_saved_for_later.append(ex) + + search(ex) + # Always search with option disabled + + # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. + # Consider implementing patters based on the executor under investingation + if ex_compile_opts: + for opt in ex_compile_opts: + op_in_trace: bool = self.op_in_trace(self.trace, opt.symbol_tag) + if op_in_trace: + # Search with option enabled + old_opt: bool | None = self.compile_data.compile_options.get(opt.fusion_tag, None) + # We test the inverse of the default one + new_opt = True if old_opt is None or old_opt is False else False + + # For nv_enable_bookend, by default is it defaulted to True, hence we try the False path + # https://github.com/Lightning-AI/lightning-thunder/blob/73b31f35ff95b08ceee7d5d5344127d619fd37fe/thunder/executors/nvfuserex_impl.py#L784 + if opt.fusion_tag == "nv_enable_bookend": + new_opt = False + + log( + f"Executor {ex.name} enabling compile option: {opt.fusion_tag} with value {new_opt}", + level=LogLevel.INFO, + ) + self.compile_data.compile_options[opt.fusion_tag] = new_opt + self.fusion_executors_saved_for_later.append(ex) + search(ex) + self.compile_data.compile_options[opt.fusion_tag] = old_opt def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: if not self.cached_fw_traces: raise AssertionError("Failed to obtain optimal fw traces") - return [getattr(candidate, field) for candidate in self.cached_fw_traces.values() for field in ['best_time', 'best_mem']] + return [ + getattr(candidate, field) + for candidate in self.cached_fw_traces.values() + for field in ["best_time", "best_mem"] + ] def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: # This is agnostic from the optimization strat as results are both floats min_value: float = float("inf") ans: FinalOutputCandidates | None = None - log(f'Computing the best pair option (tot options = {len(self.out)})', level=LogLevel.INFO) + log(f"Computing the best pair option (tot options = {len(self.out)})", level=LogLevel.INFO) for pair in self.out: if pair.tot_cost < min_value: log(f"New best pair:\n{pair}", level=LogLevel.DEBUG) min_value = pair.tot_cost ans = pair if ans is None: - raise AssertionError('Best pair not found') + raise AssertionError("Best pair not found") fw = ans.fw c, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) - log(f'Final pair fw: {c} ms - {m / (2**30)} GB\n{fw}', level=LogLevel.INFO) + log(f"Final pair fw: {c} ms - {m / (2**30)} GB\n{fw}", level=LogLevel.INFO) bw = ans.bw c, m, _ = benchmark_trace(bw, iters=self.benchmark_iters) - log(f'Final pair bw: {c} ms - {m / (2**30)} GB\n{bw}', level=LogLevel.INFO) - + log(f"Final pair bw: {c} ms - {m / (2**30)} GB\n{bw}", level=LogLevel.INFO) # To debug this: the traces that we will received in the remat call in should be the same as these and runtime should be in line with the best pair time. # The pairs above are traces with no remat call (in order to be called later on) but their tracking time are made with traces gone under the remat call @@ -878,7 +927,6 @@ def bsym_assigned(self, bsym: BoundSymbol) -> bool: return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) def benchmark_traces(self): - self.debug_msg += "Traces benchmarks:\n\n" # We cached every optimized fw traces as they might impact differently on the bw trace @@ -887,7 +935,9 @@ def fw_benchmark(): match self.optimization_algorithm: case OptimizationAlgorithm.BEST_FUSER: # The optimizator builds the results in order following the self.fusion_executors list order - for pair_time, pair_mem in zip(self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem): + for pair_time, pair_mem in zip( + self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem + ): # pair is a dict trc_time = list(pair_time.values())[0] trc_mem = list(pair_mem.values())[0] @@ -895,19 +945,18 @@ def fw_benchmark(): # TODO (matteochen): remove the benchmark here as will done later on the bw pass c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', level=LogLevel.INFO) - self.debug_msg += ( - f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', + level=LogLevel.INFO, ) + self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', level=LogLevel.INFO) - self.debug_msg += ( - f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', + level=LogLevel.INFO, ) + self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" - self.cached_fw_traces[label] = TraceCandidates(best_time = trc_time, best_mem = trc_mem) - + self.cached_fw_traces[label] = TraceCandidates(best_time=trc_time, best_mem=trc_mem) def bw_benchmark(): time_result = BenchmarkResult() @@ -919,11 +968,10 @@ def bw_benchmark(): label = list(pair.keys())[0] trace = list(pair.values())[0] trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) - self.debug_msg += ( - f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" - ) + self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" log( - f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', level=LogLevel.INFO + f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', + level=LogLevel.INFO, ) if trace_time < time_result.tm: time_result.tm = trace_time @@ -940,11 +988,10 @@ def bw_benchmark(): trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) del res - self.debug_msg += ( - f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" - ) + self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" log( - f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', level=LogLevel.INFO + f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', + level=LogLevel.INFO, ) if trace_mem < memory_result.mem: memory_result.tm = trace_time @@ -954,18 +1001,13 @@ def bw_benchmark(): memory_result.index = i log( - f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.tm} ms)":\n{time_result.trace}', level=LogLevel.INFO) + f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.tm} ms)":\n{time_result.trace}', + level=LogLevel.INFO, + ) log( - f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.mem / (2 ** 30)} GB)":\n{memory_result.trace}', level=LogLevel.INFO) - - # TODO (matteochen): remove this - # log(f"Strat comparison: {self.trace_type}") - # c, m, o = benchmark_trace(tm.trace) - # del o - # log(f"best time: {c} ms, {m/(2**30)} GB") - # c, m, o = benchmark_trace(mem.trace) - # del o - # log(f"best mem: {c} ms, {m/(2**30)} GB") + f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.mem / (2 ** 30)} GB)":\n{memory_result.trace}', + level=LogLevel.INFO, + ) # Here we have to recover the traces without the pass through remat in order to be compliant # with thunder flow as we might have request for no remat @@ -983,15 +1025,16 @@ def bw_benchmark(): # Now, finally build the pair fw and bw traces for the requested strat # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller - forward_time, forward_memory, _ = benchmark_trace( - self.active_fw_trace, self.benchmark_iters) + forward_time, forward_memory, _ = benchmark_trace(self.active_fw_trace, self.benchmark_iters) match self.optimizer_type: case OptimizerType.RUNTIME: - # Used the computed benchmark from above + # Use the computed benchmark from above if time_result.tm < memory_result.tm: log( - f"out candidate times from time res: (fw){forward_time} ms, (bw){time_result.tm} ms", level=LogLevel.INFO) + f"Output pair candidate for TIME strat from best_time res: (fw){forward_time} ms, (bw){time_result.tm} ms", + level=LogLevel.INFO, + ) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1001,7 +1044,9 @@ def bw_benchmark(): ) else: log( - f"out candidate times from mem res: (fw){forward_time} ms, (bw){memory_result.tm} ms", level=LogLevel.INFO) + f"Output pair candidate for TIME strat from best_mem res: (fw){forward_time} ms, (bw){memory_result.tm} ms", + level=LogLevel.INFO, + ) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1010,10 +1055,11 @@ def bw_benchmark(): ) ) case OptimizerType.MEMORY: - # Used the computed benchmark from above if time_result.mem < memory_result.mem: log( - f"out candidate mem from time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){time_result.mem / (2**30)} GB (bw){time_result.tm} ms", level=LogLevel.INFO) + f"Output pair candidate for MEM strat from best_time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){time_result.mem / (2**30)} GB (bw){time_result.tm} ms", + level=LogLevel.INFO, + ) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1023,7 +1069,9 @@ def bw_benchmark(): ) else: log( - f"out candidate mem from time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){memory_result.mem / (2**30)} GB (bw){memory_result.tm} ms", level=LogLevel.INFO) + f"Output pair candidate for MEM from strat best_mem res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){memory_result.mem / (2**30)} GB (bw){memory_result.tm} ms", + level=LogLevel.INFO, + ) self.out.append( FinalOutputCandidates( fw=self.active_fw_trace, @@ -1040,6 +1088,7 @@ def bw_benchmark(): if self.produce_log: import time + timestamp: str = str(time.time()) with open(f"{timestamp}-{self.log_file_name}", "w") as file: file.write(self.debug_msg) @@ -1086,7 +1135,14 @@ def is_possible_out(name: str): # TODO (matteochen): move into utils module def benchmark_trace( - trace: TraceCtx, iters: int = 1, show_func=False, apply_del_last_used=True, snapshot=False, snapshot_name="", nvsight: bool = False, nvsight_fn_name: str = "" + trace: TraceCtx, + iters: int = 1, + show_func=False, + apply_del_last_used=True, + snapshot=False, + snapshot_name="", + nvsight: bool = False, + nvsight_fn_name: str = "", ) -> tuple[float, float, Any]: from thunder.executors.passes import del_last_used import inspect @@ -1106,14 +1162,15 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f # Benchmark torch.cuda.cudart().cudaProfilerStart() for i in range(iters): - torch.cuda.nvtx.range_push(f'{nvsight_fn_name}-iter{i}') + torch.cuda.nvtx.range_push(f"{nvsight_fn_name}-iter{i}") fn(*args) torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() - return float('inf'), float('inf'), None + return float("inf"), float("inf"), None except Exception as e: import inspect + trc = inspect.getsource(fn) print(f"#NVSIGHT FN EXECUTION FAILED:\n{trc}") raise e @@ -1124,10 +1181,8 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl out = None torch.cuda.empty_cache() - start_events = [torch.cuda.Event( - enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) - for _ in range(iters)] + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] # Warm up cycles for _ in range(warm_up_iters): @@ -1136,8 +1191,7 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl if snapshot: torch.cuda.memory._record_memory_history() fn(*args) - torch.cuda.memory._dump_snapshot( - snapshot_name + "_benchmark.pickle") + torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") torch.cuda.memory._record_memory_history(enabled=None) # Benchmark stream = torch.cuda.current_stream() @@ -1151,14 +1205,12 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl fn(*args) end_events[i].record(stream) max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) + max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) ) torch.cuda.synchronize() - times = [s.elapsed_time(e) - for s, e in zip(start_events, end_events)] - print(f'times: {times}') + times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + print(f"times: {times}") tot_time = sum(times) / iters return tot_time, max_allocated_bytes, out except Exception as e: @@ -1228,8 +1280,7 @@ def transform_input_tuple(t: tuple, level=0) -> tuple: elif e is None: res.append(None) else: - raise AssertionError( - f"Input arg type not recognized: {type(e)}") + raise AssertionError(f"Input arg type not recognized: {type(e)}") return tuple(res) def transform_tensor(arg: TensorProxy) -> torch.Tensor: @@ -1290,15 +1341,13 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: input_args.append(e) elif isinstance(arg, IntegerProxy): if arg.python_type is bool: - input_args.append( - False if arg.value is None else arg.value) + input_args.append(False if arg.value is None else arg.value) else: input_args.append(0 if arg.value is None else arg.value) elif isinstance(arg, FloatProxy): input_args.append(0.0 if arg.value is None else arg.value) else: - raise AssertionError( - f"Input arg type not recognized: {type(arg)}") + raise AssertionError(f"Input arg type not recognized: {type(arg)}") else: raise AssertionError("Unexpexcted args type") @@ -1325,14 +1374,16 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: # https://github.com/Lightning-AI/lightning-thunder/issues/664 # Seems that this patch never work ... print(f"Exception:\n{e}") - if "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) and not nvsight: + if ( + "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) + and not nvsight + ): print( "Executing with torch compile no full graph (this might still fail), see: https://github.com/Lightning-AI/lightning-thunder/issues/664" ) torch_compiled = torch.compile(executable, fullgraph=False) try: - t, m, answer = compute_time_cost_ms( - torch_compiled, iters, *input_args) + t, m, answer = compute_time_cost_ms(torch_compiled, iters, *input_args) except Exception as e: print(f"Compiled trace execution still failed:\n{e}") else: @@ -1341,4 +1392,3 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: reset_tracectx(trace_tok) return t, m, answer - diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 04c1095ade..de29387685 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -210,6 +210,7 @@ def split(): produce_log=True, visualizer=visualizer, optimizer_type=autotune_type, + compile_data=compile_data ) ) From 5be42dc9ddc05ba46bb67b4ac5220079a56bdc11 Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 5 Aug 2024 13:48:59 +0300 Subject: [PATCH 027/171] Benchmarking different compile options for nvFuser (#8) From fceed7ef63e075340b2dd47fdefa6adbf03c689d Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 6 Aug 2024 11:11:31 +0300 Subject: [PATCH 028/171] Moved `runtime` or `memory` selection at the end of seach when all candidates are available (#9) --- examples/dev/nanogpt.py | 180 ++++++++++++++++++- thunder/backend_optimizer/optimizer.py | 240 ++++++++++++++++--------- thunder/executors/torch_autograd.py | 30 ++-- 3 files changed, 350 insertions(+), 100 deletions(-) diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index aab28ab332..b7a0b221be 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -6,7 +6,7 @@ warm_up_iters = 50 -def run(): +def run_time(): # ----------------------------------------------------------------------------- batch_size = 12 block_size = 1024 @@ -50,7 +50,159 @@ def run(): # model init gptconf = GPTConfig( block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 2, n_head = 12, n_embd = 768, # size of the model + n_layer = 1, n_head = 12, n_embd = 768, # size of the model + dropout = 0, # for determinism + bias = bias, + ) + model = GPT(gptconf) + model.to(device) + + jmodel_def = thunder.jit(model) + # Currently sdpa does not work? + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) + + # optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) + + if compile: + print("Compiling model...") + model = torch.compile(model) # pytorch 2.0 + + if profile: + # useful docs on pytorch profiler: + # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html + # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile + wait, warmup, active = 5, 5, 5 + num_steps = wait + warmup + active + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), + record_shapes=False, + profile_memory=False, + with_stack=False, # incurs an additional overhead, disable if not needed + with_flops=True, + with_modules=False, # only for torchscript models atm + ) as prof: + + models = [jmodel_def, jmodel_auto] + + for mod in models: + print('Profiling new model') + X, Y = get_batch('train') + for k in range(num_steps): + with ctx: + _, loss = model(X, Y) + X, Y = get_batch('train') + # optimizer.zero_grad(set_to_none=True) + loss.backward() + # optimizer.step() + lossf = loss.item() + print(f"{k}/{num_steps} loss: {lossf:.4f}") + + prof.step() # notify the profiler at end of each step + + else: + def measure(m, label): + # simple benchmarking + torch.cuda.synchronize() + + X, Y = get_batch('train') + for i in range(warm_up_iters): + with ctx: + _, loss = m(X, Y) + X, Y = get_batch('train') + loss.backward() + + iters = 5 + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + torch.cuda.synchronize() + X, Y = get_batch('train') + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + with ctx: + _, loss = m(X, Y) + X, Y = get_batch('train') + loss.backward() + end_events[i].record(stream) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print('\n\nResults torch benchmark:') + print(f'{label} tot time: {tot_time} ms') + + measure(jmodel_auto, 'auto') + measure(jmodel_def, 'def') + + print('\n\nResults thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces.reverse() + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + labels.reverse() + thunder_fw_bw_benchmark(traces, labels, 5) + + # X, Y = get_batch('train') + # out_eager = model(X, Y) + # out_def = jmodel_def(X, Y) + # out_auto = jmodel_auto(X, Y) + # for a, b in zip(out_eager, out_def): + # print('deviation def:', (a - b).abs().max().item()) + # for a, b in zip(out_eager, out_auto): + # print('deviation auto:', (a - b).abs().max().item()) + + # traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + # for t in traces: + # print(f'{t}\n############################################') + +def run_memory(): + # ----------------------------------------------------------------------------- + batch_size = 12 + block_size = 1024 + bias = False + real_data = False + seed = 1337 + device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. + dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' + compile = False # use PyTorch 2.0 to compile the model to be faster + profile = False # use pytorch profiler, or just simple benchmarking? + # exec(open('configurator.py').read()) # overrides from command line or config file + # ----------------------------------------------------------------------------- + + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast + ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] + ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) + + # data loading init + if real_data: + raise RuntimeError('Not supported') + # dataset = 'openwebtext' + # data_dir = os.path.join('data', dataset) + # train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') + # def get_batch(split): + # data = train_data # note ignore split in benchmarking script + # ix = torch.randint(len(data) - block_size, (batch_size,)) + # x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) + # y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) + # x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) + # return x, y + else: + # alternatively, if fixed data is desired to not care about data loading + x = torch.randint(50304, (batch_size, block_size), device=device) + y = torch.randint(50304, (batch_size, block_size), device=device) + get_batch = lambda split: (x, y) + + # model init + gptconf = GPTConfig( + block_size = block_size, # how far back does the model look? i.e. context size + n_layer = 1, n_head = 12, n_embd = 768, # size of the model dropout = 0, # for determinism bias = bias, ) @@ -91,7 +243,7 @@ def run(): X, Y = get_batch('train') for k in range(num_steps): with ctx: - logits, loss = model(X, Y) + _, loss = model(X, Y) X, Y = get_batch('train') # optimizer.zero_grad(set_to_none=True) loss.backward() @@ -109,7 +261,7 @@ def measure(m, label): X, Y = get_batch('train') for i in range(warm_up_iters): with ctx: - logits, loss = m(X, Y) + _, loss = m(X, Y) X, Y = get_batch('train') loss.backward() @@ -124,7 +276,7 @@ def measure(m, label): torch.cuda._sleep(1_000_000) start_events[i].record(stream) with ctx: - logits, loss = m(X, Y) + _, loss = m(X, Y) X, Y = get_batch('train') loss.backward() end_events[i].record(stream) @@ -145,8 +297,18 @@ def measure(m, label): labels.reverse() thunder_fw_bw_benchmark(traces, labels, 5) - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] - for t in traces: - print(f'{t}\n############################################') + # X, Y = get_batch('train') + # out_eager = model(X, Y) + # out_def = jmodel_def(X, Y) + # out_auto = jmodel_auto(X, Y) + # for a, b in zip(out_eager, out_def): + # print('deviation def:', (a - b).abs().max().item()) + # for a, b in zip(out_eager, out_auto): + # print('deviation auto:', (a - b).abs().max().item()) + + # traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + # for t in traces: + # print(f'{t}\n############################################') -run() +run_memory() +run_time() diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 12d51bee98..73b177b67a 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -70,8 +70,8 @@ def iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: return self.best_time, self.best_mem -class FinalOutputCandidates: - def __init__(self, *, fw: TraceCtx, bw: TraceCtx, cost: float) -> None: +class OutputCandidate: + def __init__(self, *, fw: TraceCtx, bw: TraceCtx, cost: float = 0.0) -> None: self.fw: TraceCtx = fw self.bw: TraceCtx = bw self.tot_cost: float = cost @@ -153,7 +153,9 @@ def __init__( self.active_fw_trace: TraceCtx | None = None self.cached_fw_traces: dict[str | Hashable, TraceCandidates] = {} self.bw_trace_candidates: TraceCandidates = TraceCandidates() - self.out: list[FinalOutputCandidates] = [] + self.out_traces_candidates: list[OutputCandidate] = [] + self.best_pair_runtime: OutputCandidate + self.best_pair_memory: OutputCandidate # Strat fusion self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() @@ -449,6 +451,38 @@ def match_optimizer_algorithm(): match_optimizer_algorithm() + from thunder.core.rematerialization import rematerialize_forward_and_backward + min_value_time: float = float("inf") + min_value_mem: float = float("inf") + best_pair_runtime: OutputCandidate + best_pair_memory: OutputCandidate + for pair in self.out_traces_candidates: + # Apply remat and select best trace pair + pair_cost_time = 0 + pair_cost_mem = 0 + remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) + t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) + log(f'Pair fw time: {t}, mem: {m}', level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) + log(f'Pair bw time: {t}, mem: {m}', level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + + if pair_cost_time < min_value_time: + best_pair_runtime = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_time) + log(f"New best runtime pair:\n{best_pair_runtime}", level=LogLevel.INFO) + min_value_time = pair_cost_time + + if pair_cost_mem < min_value_mem: + best_pair_memory = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_mem) + log(f"New best memory pair:\n{best_pair_memory}", level=LogLevel.INFO) + min_value_mem = pair_cost_mem + + self.best_pair_runtime = best_pair_runtime + self.best_pair_memory = best_pair_memory + def can_executor_execute(self, ex: Executor, bsym: BoundSymbol) -> bool: try: return ex.can_execute(bsym) @@ -900,28 +934,45 @@ def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: ] def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: - # This is agnostic from the optimization strat as results are both floats - min_value: float = float("inf") - ans: FinalOutputCandidates | None = None - log(f"Computing the best pair option (tot options = {len(self.out)})", level=LogLevel.INFO) - for pair in self.out: - if pair.tot_cost < min_value: - log(f"New best pair:\n{pair}", level=LogLevel.DEBUG) - min_value = pair.tot_cost - ans = pair - if ans is None: - raise AssertionError("Best pair not found") - - fw = ans.fw - c, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) - log(f"Final pair fw: {c} ms - {m / (2**30)} GB\n{fw}", level=LogLevel.INFO) - bw = ans.bw - c, m, _ = benchmark_trace(bw, iters=self.benchmark_iters) - log(f"Final pair bw: {c} ms - {m / (2**30)} GB\n{bw}", level=LogLevel.INFO) - - # To debug this: the traces that we will received in the remat call in should be the same as these and runtime should be in line with the best pair time. - # The pairs above are traces with no remat call (in order to be called later on) but their tracking time are made with traces gone under the remat call - return ans.fw, ans.bw + return (self.best_pair_runtime.fw, self.best_pair_runtime.bw) if self.optimizer_type == OptimizerType.RUNTIME else (self.best_pair_memory.fw, self.best_pair_memory.bw) + # from thunder.core.rematerialization import rematerialize_forward_and_backward + + # # This is agnostic from the optimization strat as results are both floats + # min_value: float = float("inf") + # ans: OutputCandidate | None = None + # log(f"Computing the best pair option (tot options = {len(self.out_traces_candidates)})", level=LogLevel.INFO) + # for pair in self.out_traces_candidates: + + # # Apply remat and select best trace pair + # pair_cost = 0 + # remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) + # t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) + # log(f'Pair fw time: {t}, mem: {m}', level=LogLevel.DEBUG) + # pair_cost = pair_cost + t if self.optimizer_type == OptimizerType.RUNTIME else m + # t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) + # log(f'Pair bw time: {t}, mem: {m}', level=LogLevel.DEBUG) + # pair_cost = pair_cost + t if self.optimizer_type == OptimizerType.RUNTIME else m + # pair.fw = remat_fw + # pair.bw = remat_bw + # pair.tot_cost = pair_cost + + # if pair.tot_cost < min_value: + # log(f"New best pair:\n{pair}", level=LogLevel.DEBUG) + # min_value = pair.tot_cost + # ans = pair + # if ans is None: + # raise AssertionError("Best pair not found") + + # fw = ans.fw + # c, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) + # log(f"Final candidate pair fw: {c} ms - {m / (2**30)} GB\n{fw}", level=LogLevel.INFO) + # bw = ans.bw + # c, m, _ = benchmark_trace(bw, iters=self.benchmark_iters) + # log(f"Final candidate pair bw: {c} ms - {m / (2**30)} GB\n{bw}", level=LogLevel.INFO) + + # # To debug this: the traces that we will received in the remat call in should be the same as these and runtime should be in line with the best pair time. + # # The pairs above are traces with no remat call (in order to be called later on) but their tracking time are made with traces gone under the remat call + # return ans.fw, ans.bw def bsym_assigned(self, bsym: BoundSymbol) -> bool: return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) @@ -929,7 +980,7 @@ def bsym_assigned(self, bsym: BoundSymbol) -> bool: def benchmark_traces(self): self.debug_msg += "Traces benchmarks:\n\n" - # We cached every optimized fw traces as they might impact differently on the bw trace + # We cache every optimized fw traces as they might impact differently on the bw trace # Number of fw traces to cached are: #fusion_executors * 2 def fw_benchmark(): match self.optimization_algorithm: @@ -1023,62 +1074,91 @@ def bw_benchmark(): log(self.bw_trace_candidates.__repr__(), level=LogLevel.DEBUG) - # Now, finally build the pair fw and bw traces for the requested strat + # Now, finally build the pair fw and bw traces # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller - forward_time, forward_memory, _ = benchmark_trace(self.active_fw_trace, self.benchmark_iters) - match self.optimizer_type: - case OptimizerType.RUNTIME: - # Use the computed benchmark from above - if time_result.tm < memory_result.tm: - log( - f"Output pair candidate for TIME strat from best_time res: (fw){forward_time} ms, (bw){time_result.tm} ms", - level=LogLevel.INFO, - ) - self.out.append( - FinalOutputCandidates( - fw=self.active_fw_trace, - bw=self.bw_trace_candidates.best_time, - cost=forward_time + time_result.tm, - ) - ) - else: - log( - f"Output pair candidate for TIME strat from best_mem res: (fw){forward_time} ms, (bw){memory_result.tm} ms", - level=LogLevel.INFO, - ) - self.out.append( - FinalOutputCandidates( - fw=self.active_fw_trace, - bw=self.bw_trace_candidates.best_mem, - cost=forward_time + memory_result.tm, - ) - ) - case OptimizerType.MEMORY: - if time_result.mem < memory_result.mem: - log( - f"Output pair candidate for MEM strat from best_time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){time_result.mem / (2**30)} GB (bw){time_result.tm} ms", - level=LogLevel.INFO, - ) - self.out.append( - FinalOutputCandidates( - fw=self.active_fw_trace, - bw=self.bw_trace_candidates.best_time, - cost=forward_memory + time_result.mem, - ) - ) - else: - log( - f"Output pair candidate for MEM from strat best_mem res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){memory_result.mem / (2**30)} GB (bw){memory_result.tm} ms", - level=LogLevel.INFO, - ) - self.out.append( - FinalOutputCandidates( - fw=self.active_fw_trace, - bw=self.bw_trace_candidates.best_mem, - cost=forward_memory + memory_result.mem, - ) - ) + # forward_time, forward_memory, _ = benchmark_trace(self.active_fw_trace, self.benchmark_iters) + + # from thunder.core.rematerialization import rematerialize_forward_and_backward + + # min_value_time: float = float("inf") + # min_value_mem: float = float("inf") + # best_pair_runtime: OutputCandidate + # best_pair_memory: OutputCandidate + for bw in self.bw_trace_candidates.iterable(): + + self.out_traces_candidates.append(OutputCandidate(fw=self.active_fw_trace, bw=bw)) + + # Apply remat and select best trace pair + # pair_cost_time = 0 + # pair_cost_mem = 0 + # remat_fw, remat_bw = rematerialize_forward_and_backward(self.active_fw_trace, bw) + # t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) + # log(f'Pair fw time: {t}, mem: {m}', level=LogLevel.INFO) + # pair_cost_time = pair_cost_time + t + # pair_cost_mem = pair_cost_mem + m + # t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) + # log(f'Pair bw time: {t}, mem: {m}', level=LogLevel.INFO) + # pair_cost_time = pair_cost_time + t + # pair_cost_mem = pair_cost_mem + m + + # if pair_cost_time < min_value_time: + # best_pair_runtime = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_time) + # log(f"New best runtime pair:\n{best_pair_runtime}", level=LogLevel.INFO) + # min_value_time = pair_cost_time + + # if pair_cost_mem < min_value_mem: + # best_pair_memory = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_mem) + # log(f"New best memory pair:\n{best_pair_memory}", level=LogLevel.INFO) + # min_value_mem = pair_cost_mem + + # self.best_pair_runtime = best_pair_runtime + # self.best_pair_memory = best_pair_memory + + # log( + # f"Output pair candidate for TIME strat from best_time res: (fw){forward_time} ms, (bw){time_result.tm} ms", + # level=LogLevel.INFO, + # ) + # self.out_traces_candidates.append( + # OutputCandidate( + # fw=self.active_fw_trace, + # bw=self.bw_trace_candidates.best_time, + # cost=forward_time + time_result.tm, + # ) + # ) + # log( + # f"Output pair candidate for TIME strat from best_mem res: (fw){forward_time} ms, (bw){memory_result.tm} ms", + # level=LogLevel.INFO, + # ) + # self.out_traces_candidates.append( + # OutputCandidate( + # fw=self.active_fw_trace, + # bw=self.bw_trace_candidates.best_mem, + # cost=forward_time + memory_result.tm, + # ) + # ) + # log( + # f"Output pair candidate for MEM strat from best_time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){time_result.mem / (2**30)} GB (bw){time_result.tm} ms", + # level=LogLevel.INFO, + # ) + # self.out_traces_candidates.append( + # OutputCandidate( + # fw=self.active_fw_trace, + # bw=self.bw_trace_candidates.best_time, + # cost=forward_memory + time_result.mem, + # ) + # ) + # log( + # f"Output pair candidate for MEM from strat best_mem res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){memory_result.mem / (2**30)} GB (bw){memory_result.tm} ms", + # level=LogLevel.INFO, + # ) + # self.out_traces_candidates.append( + # OutputCandidate( + # fw=self.active_fw_trace, + # bw=self.bw_trace_candidates.best_mem, + # cost=forward_memory + memory_result.mem, + # ) + # ) match self.trace_type: case TraceType.FW: diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index de29387685..bebd296ed7 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -253,17 +253,25 @@ def split(): bw_traces.append(bw_extrace) visualizer.set_bw_optimized_trace(bw_extrace) - # TODO Restore request for no rematerialization - # TODO (matteochen): remove these logs - c, m, _ = benchmark_trace(fw_extrace, iters=5) - log(f'before remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) - c, m, _ = benchmark_trace(bw_extrace, iters=5) - log(f'before remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) - fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) - c, m, _ = benchmark_trace(fw_extrace, iters=5) - log(f'after remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) - c, m, _ = benchmark_trace(bw_extrace, iters=5) - log(f'after remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) + if autotune_type is None: + # TODO Restore request for no rematerialization + # TODO (matteochen): remove these logs + c, m, _ = benchmark_trace(fw_extrace, iters=5) + log(f'before remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) + c, m, _ = benchmark_trace(bw_extrace, iters=5) + log(f'before remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) + fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + c, m, _ = benchmark_trace(fw_extrace, iters=5) + log(f'after remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) + c, m, _ = benchmark_trace(bw_extrace, iters=5) + log(f'after remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) + # Autotuner has been taken care of remat + else: + # TODO (matteochen): remove this + c, m, _ = benchmark_trace(fw_extrace, iters=5) + log(f'after remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) + c, m, _ = benchmark_trace(bw_extrace, iters=5) + log(f'after remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) From 53da20c65616a6f7cd2511b6c49021ca47ea7de8 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 6 Aug 2024 22:53:55 +0300 Subject: [PATCH 029/171] Refactoring autotune code (#10) --- examples/dev/nanogpt.py | 174 +--- thunder/backend_optimizer/optimizer.py | 1302 +++++++----------------- thunder/backend_optimizer/utils.py | 503 +++++++++ thunder/benchmarks/utils.py | 2 +- 4 files changed, 880 insertions(+), 1101 deletions(-) create mode 100644 thunder/backend_optimizer/utils.py diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index b7a0b221be..c034c5bc0e 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -6,7 +6,9 @@ warm_up_iters = 50 -def run_time(): +def run(target: str = 'runtime'): + if target != 'runtime' and target != 'memory': + raise AssertionError(f'Target {target} not supported. Only runtime and memory available') # ----------------------------------------------------------------------------- batch_size = 12 block_size = 1024 @@ -31,16 +33,6 @@ def run_time(): # data loading init if real_data: raise RuntimeError('Not supported') - # dataset = 'openwebtext' - # data_dir = os.path.join('data', dataset) - # train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') - # def get_batch(split): - # data = train_data # note ignore split in benchmarking script - # ix = torch.randint(len(data) - block_size, (batch_size,)) - # x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) - # y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) - # x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) - # return x, y else: # alternatively, if fixed data is desired to not care about data loading x = torch.randint(50304, (batch_size, block_size), device=device) @@ -50,7 +42,7 @@ def run_time(): # model init gptconf = GPTConfig( block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 1, n_head = 12, n_embd = 768, # size of the model + n_layer = 4, n_head = 12, n_embd = 768, # size of the model dropout = 0, # for determinism bias = bias, ) @@ -59,161 +51,7 @@ def run_time(): jmodel_def = thunder.jit(model) # Currently sdpa does not work? - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) - - # optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) - - if compile: - print("Compiling model...") - model = torch.compile(model) # pytorch 2.0 - - if profile: - # useful docs on pytorch profiler: - # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html - # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile - wait, warmup, active = 5, 5, 5 - num_steps = wait + warmup + active - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), - record_shapes=False, - profile_memory=False, - with_stack=False, # incurs an additional overhead, disable if not needed - with_flops=True, - with_modules=False, # only for torchscript models atm - ) as prof: - - models = [jmodel_def, jmodel_auto] - - for mod in models: - print('Profiling new model') - X, Y = get_batch('train') - for k in range(num_steps): - with ctx: - _, loss = model(X, Y) - X, Y = get_batch('train') - # optimizer.zero_grad(set_to_none=True) - loss.backward() - # optimizer.step() - lossf = loss.item() - print(f"{k}/{num_steps} loss: {lossf:.4f}") - - prof.step() # notify the profiler at end of each step - - else: - def measure(m, label): - # simple benchmarking - torch.cuda.synchronize() - - X, Y = get_batch('train') - for i in range(warm_up_iters): - with ctx: - _, loss = m(X, Y) - X, Y = get_batch('train') - loss.backward() - - iters = 5 - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - stream = torch.cuda.current_stream() - torch.cuda.synchronize() - X, Y = get_batch('train') - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - with ctx: - _, loss = m(X, Y) - X, Y = get_batch('train') - loss.backward() - end_events[i].record(stream) - - torch.cuda.synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print('\n\nResults torch benchmark:') - print(f'{label} tot time: {tot_time} ms') - - measure(jmodel_auto, 'auto') - measure(jmodel_def, 'def') - - print('\n\nResults thunder benchmark:') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] - traces.reverse() - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - labels.reverse() - thunder_fw_bw_benchmark(traces, labels, 5) - - # X, Y = get_batch('train') - # out_eager = model(X, Y) - # out_def = jmodel_def(X, Y) - # out_auto = jmodel_auto(X, Y) - # for a, b in zip(out_eager, out_def): - # print('deviation def:', (a - b).abs().max().item()) - # for a, b in zip(out_eager, out_auto): - # print('deviation auto:', (a - b).abs().max().item()) - - # traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] - # for t in traces: - # print(f'{t}\n############################################') - -def run_memory(): - # ----------------------------------------------------------------------------- - batch_size = 12 - block_size = 1024 - bias = False - real_data = False - seed = 1337 - device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. - dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' - compile = False # use PyTorch 2.0 to compile the model to be faster - profile = False # use pytorch profiler, or just simple benchmarking? - # exec(open('configurator.py').read()) # overrides from command line or config file - # ----------------------------------------------------------------------------- - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul - torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast - ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] - ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) - - # data loading init - if real_data: - raise RuntimeError('Not supported') - # dataset = 'openwebtext' - # data_dir = os.path.join('data', dataset) - # train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') - # def get_batch(split): - # data = train_data # note ignore split in benchmarking script - # ix = torch.randint(len(data) - block_size, (batch_size,)) - # x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) - # y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) - # x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) - # return x, y - else: - # alternatively, if fixed data is desired to not care about data loading - x = torch.randint(50304, (batch_size, block_size), device=device) - y = torch.randint(50304, (batch_size, block_size), device=device) - get_batch = lambda split: (x, y) - - # model init - gptconf = GPTConfig( - block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 1, n_head = 12, n_embd = 768, # size of the model - dropout = 0, # for determinism - bias = bias, - ) - model = GPT(gptconf) - model.to(device) - - jmodel_def = thunder.jit(model) - # Currently sdpa does not work? - jmodel_auto = thunder.jit(model, autotune_type='memory', executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) - - # optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) + jmodel_auto = thunder.jit(model, autotune_type={target}, executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) if compile: print("Compiling model...") @@ -245,9 +83,7 @@ def run_memory(): with ctx: _, loss = model(X, Y) X, Y = get_batch('train') - # optimizer.zero_grad(set_to_none=True) loss.backward() - # optimizer.step() lossf = loss.item() print(f"{k}/{num_steps} loss: {lossf:.4f}") diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 73b177b67a..98dd8319db 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,29 +1,33 @@ -from collections.abc import Callable, Sequence +from collections.abc import Sequence from enum import Enum -from itertools import chain -from thunder.core.dtypes import dtype, is_boolean_dtype +from thunder.backend_optimizer.utils import operation_in_trace from thunder.core.prims import PrimIDs -from thunder.core.utils import check, safe_map_flat -from thunder.core.baseutils import BoundSymbolInterface -from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, variableify, Variable -from thunder.core.symbol import BoundSymbol, Symbol -from thunder.core.trace import from_trace, set_tracectx, reset_tracectx, get_tracectx, TraceCtx +from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, TensorProxy +from thunder.core.symbol import BoundSymbol +from thunder.core.trace import from_trace, TraceCtx from thunder.executors.data_dependent_partition import Node from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors from thunder.visualizer.visualizer_helper import Visualizer -from typing import Any, Hashable -import thunder -import thunder.core.transforms as transforms -import torch +from typing import Hashable +from thunder.backend_optimizer.utils import benchmark_trace + # Currently this manages both time and memory class BenchmarkResult: - def __init__(self) -> None: - self.tm: float = float("inf") - self.mem: float = float("inf") - self.trace: TraceCtx | None = None - self.label: str | Hashable = "" - self.index = -1 + def __init__( + self, + *, + time: float = float("inf"), + memory: float = float("inf"), + trace: TraceCtx = TraceCtx(), + label: str | Hashable = "", + index: int = -1, + ) -> None: + self.runtime: float = time + self.memory: float = memory + self.trace: TraceCtx = trace + self.label: str | Hashable = label + self.index: int = index class OptimizerType(Enum): @@ -118,16 +122,14 @@ def __init__(self, fusion_tag: str, symbol_tag: str) -> None: self.symbol_tag = symbol_tag -class BackendOptimizer: - # from thunder.common import CompileData - +class FusionPlacer: def __init__( self, *, priority_executors: Sequence[Executor], - produce_log=True, + produce_log: bool = True, apply_bucketing_bw_trace: bool, - log_file_name="autotune_debug.log", + log_file_name: str, visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME, compile_data, @@ -148,7 +150,6 @@ def __init__( self.produce_log: bool = produce_log self.optimizer_type: OptimizerType = optimizer_type - self.optimization_algorithm: OptimizationAlgorithm | None = None self.active_fw_trace: TraceCtx | None = None self.cached_fw_traces: dict[str | Hashable, TraceCandidates] = {} @@ -168,368 +169,175 @@ def __init__( self.compile_data = compile_data self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { - "nvfuser": [ - FusionCompileOptionsHelper("nv_enable_linear", "linear"), - FusionCompileOptionsHelper("nv_enable_matmul", "matmul"), - FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), - ] + # "nvfuser": [ + # FusionCompileOptionsHelper("nv_enable_linear", "linear"), + # FusionCompileOptionsHelper("nv_enable_matmul", "matmul"), + # FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), + # ] } - log("Executors:", level=LogLevel.INFO) - for e in self.executors: - log( - f"{e.name} -> is operator = {isinstance(e, OperatorExecutor)}, is fusion = {isinstance(e, FusionExecutor)}", - level=LogLevel.INFO, - ) - - class SearchNode: - def __init__(self, symbol: BoundSymbolInterface, idx: int) -> None: - self.symbol = symbol - self.idx = idx - - def attach_cached_fw_traces(self, cached_fw_traces: TraceCandidates, executor_name: str) -> None: - self.cached_fw_traces[executor_name] = cached_fw_traces - - def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): - from thunder.core.transform_common import dce + """ + ################################################## Internal methods ################################################## + """ - self.trace_type = trace_type - # dce for the backward trace will be passed afterwards - self.trace: TraceCtx = dce(trace) if trace_type == TraceType.FW else trace + def _best_runtime_and_memory_candidates(self, candidates): + from thunder.core.rematerialization import rematerialize_forward_and_backward + from thunder.backend_optimizer.utils import benchmark_trace + + min_value_time: float = float("inf") + min_value_mem: float = float("inf") + best_pair_runtime: OutputCandidate + best_pair_memory: OutputCandidate + for pair in candidates: + # Apply remat and select best trace pair + pair_cost_time = 0 + pair_cost_mem = 0 + remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) + t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) + log(f"Pair fw time: {t}, mem: {m}", level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) + log(f"Pair bw time: {t}, mem: {m}", level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + + if pair_cost_time < min_value_time: + best_pair_runtime = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_time) + log(f"New best runtime pair:\n{best_pair_runtime}", level=LogLevel.INFO) + min_value_time = pair_cost_time + + if pair_cost_mem < min_value_mem: + best_pair_memory = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_mem) + log(f"New best memory pair:\n{best_pair_memory}", level=LogLevel.INFO) + min_value_mem = pair_cost_mem + + return best_pair_runtime, best_pair_memory + + def _filter_candidates(self): + self.debug_msg += "Traces benchmarks:\n\n" - match self.trace_type: - case TraceType.FW: + # We cache every optimized fw traces as they might impact differently on the bw trace + # Number of fw traces to cached are: #fusion_executors * 2 + def fw_benchmark(): + # The optimizator builds the results in order following the self.fusion_executors list order + for pair_time, pair_mem in zip( + self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem + ): + # pair is a dict + trc_time = list(pair_time.values())[0] + trc_mem = list(pair_mem.values())[0] + label = list(pair_time.keys())[0] + # TODO (matteochen): remove the benchmark here as will done later on the bw pass + c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) log( - f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', + level=LogLevel.INFO, ) - # TODO (matteochen): support bw trace optimization even though with no fw traces cached - case TraceType.BW: - if not self.cached_fw_traces: - raise AssertionError("Can not optimize backward traces before forward traces") + self.debug_msg += ( + f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" + ) + c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) log( - f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", + f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', level=LogLevel.INFO, ) - - def place_optimizers(self, in_trace, executor_list: list[Executor]) -> TraceCtx: - from thunder.executors.passes import _transform_for_operator_executor_execution - - swapmap: dict[Variable, Proxy] = {} - - def restore_correct_args(trace_in: TraceCtx): - def args_eq(a, b) -> bool: - if len(a) != len(b): - return False - for obj_a, obj_b in zip(a, b): - if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): - if obj_a.name != obj_b.name: - return False - elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): - if obj_a != obj_b: - raise AssertionError(f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") - return True - - def clear(bsym: BoundSymbol, input): - size = len(bsym.subsymbols) - if size > 0: - for subsym in bsym.subsymbols: - if not args_eq(subsym.args, input): - subsym.args = tuple(list(input)) - clear(subsym, input) - - for bsym in trace_in.bound_symbols: - if isinstance(bsym.sym.executor, OperatorExecutor): - clear(bsym, bsym.args) - - def update_swapmap(o: Any, no: Any) -> None: - if isinstance(o, Proxy): - check( - isinstance(no, Proxy), - lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}", + self.debug_msg += ( + f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) - vo = variableify(o) - vno = variableify(no) - if vo == vno: - return - swapmap[vno] = o - - def preserve_bsym(bsym: BoundSymbol) -> Any: - trace: TraceCtx | None = get_tracectx() - if trace is None: - raise AssertionError("None trace context") - trace.scopes[-1].append(bsym) - for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): - trace.names.add(p.name) - return bsym.output - - def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: - if bsym.sym.python_impl is not None: - return None - - # We have mapped this at previous stages - if ex.name == self.empty_executor_hashable_placeholder: - return None - - execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) - out: Any - if execution_transform is not None: - out = execution_transform(*bsym.args, **bsym.kwargs) - elif isinstance(ex, OperatorExecutor): - # Calls the operator executor's operation - op: Symbol | None = ex.implmap[bsym.sym.id].symbol - if op is None: - raise AssertionError("op is None") - out = op(*bsym.args, **bsym.kwargs) - elif isinstance(ex, FusionExecutor): - # Preserves the symbol as is (it will be handled in the fusion pass) - out = preserve_bsym(bsym) - else: - raise AssertionError("Unknown executor") - - safe_map_flat(update_swapmap, bsym.output, out) - - return True - - def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: - return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE - - if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError("len(executor_list) != len(in_trace.bound_symbols)") - - # log(f'Visit transf') - # for n, e in zip(in_trace.bound_symbols, executor_list): - # print(f'{n.sym.name} -> {e.name}') - cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} - executor_mapping: dict[str, Executor] = {} - unique_fusion_executors = set() - - # Input should have equal length - if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") - - for b, e in zip(in_trace.bound_symbols, executor_list): - if isinstance(e, FusionExecutor): - unique_fusion_executors.add(e) - if isinstance(b.output, TensorProxy): - executor_mapping[b.output.name] = e - - extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) - - # Restores original variables - bound_symbols: list[BoundSymbol] = [] - for bsym in extrace.bound_symbols: - nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) - bound_symbols.append(nbsym) - extrace.bound_symbols = bound_symbols - - for bsym in extrace.bound_symbols: - if isinstance(bsym.output, TensorProxy): - t_name = bsym.output.name - if t_name not in executor_mapping: - # Symbol added by the visitor - continue - # raise AssertionError('Failed to retrive key in mapping') - saved_ex = executor_mapping[t_name] - if isinstance(saved_ex, OperatorExecutor): - cached_subsymbols[t_name] = list(bsym.subsymbols) - # This will leave out these symbols from the fusion pass - bsym.subsymbols = [] - - # Perform fusion pass - for ex in unique_fusion_executors: - extrace = ex.fusion_pass(extrace) - - # Restore subsymbols - # TODO (matteochen): Improve this search - for k, v in cached_subsymbols.items(): - # Note some symbols may be cut out by the fusion pass -> CSE - # For example: - # a = 1 + 1 - # b = 1 + 1 - # c = a + b - # being replaced by c = a + a - for bsym in extrace.bound_symbols: - if isinstance(bsym.output, TensorProxy) and bsym.output.name == k: - bsym.subsymbols = v - - restore_correct_args(extrace) - - # Apply always executors - extrace = _transform_for_operator_executor_execution(extrace, self.always_executors) - - return extrace - - def op_in_trace(self, trace: TraceCtx, op: str) -> bool: - # Some optimizations are not available as symbols - always_true = set(["bookend"]) - - if op in always_true: - return True - for b in trace.bound_symbols: - if b.sym.name == op: - return True - return False - - def optimize(self, strat: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER): - from thunder.core.transform_common import dce - from thunder.executors.torch_autograd import update_bw_from_forward_optimization + self.cached_fw_traces[label] = TraceCandidates(best_time=trc_time, best_mem=trc_mem) - self.optimization_algorithm = strat + def bw_benchmark(): + time_result = BenchmarkResult() + memory_result = BenchmarkResult() - def optmize_best_fuser(): - # Reset fusion helpers - self.fusion_strat_helper = FusionStratHelper() - # Reset helpers data structures - self.executor_placement_options = ExecutorPlacementOptions() + # Find best trace for runtime + for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_time_benchmark_only): + # Unpack the dict + label = list(pair.keys())[0] + trace = list(pair.values())[0] + trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) + self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" + log( + f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', + level=LogLevel.INFO, + ) + if trace_time < time_result.runtime: + time_result = BenchmarkResult(time=trace_time, memory=trace_mem, trace=trace, label=label, index=i) - self.build_placement_options_best_fuser() + # Find best trace for memory + for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_mem_benchmark_only): + # Unpack the dict + label = list(pair.keys())[0] + trace = list(pair.values())[0] - if len(self.executor_placement_options.placement_options_time) != len( - self.fusion_executors_saved_for_later - ): - raise AssertionError( - f"Unexpected time placement options size: {len(self.executor_placement_options.placement_options_time)}. Expected: {len(self.fusion_executors)}" - ) - if len(self.executor_placement_options.placement_options_mem) != len(self.fusion_executors_saved_for_later): - raise AssertionError( - f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors)}" + trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) + del res + self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" + log( + f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', + level=LogLevel.INFO, ) + if trace_mem < memory_result.memory: + memory_result = BenchmarkResult( + time=trace_time, memory=trace_mem, trace=trace, label=label, index=i + ) - for placement, ex in zip( - self.executor_placement_options.placement_options_time, self.fusion_executors_saved_for_later - ): - self.fusion_strat_helper.optimized_traces_time.append( - {ex.name: self.place_optimizers(self.trace, placement)} - ) - for placement, ex in zip( - self.executor_placement_options.placement_options_mem, self.fusion_executors_saved_for_later - ): - self.fusion_strat_helper.optimized_traces_mem.append( - {ex.name: self.place_optimizers(self.trace, placement)} - ) + log( + f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.runtime} ms)":\n{time_result.trace}', + level=LogLevel.INFO, + ) + log( + f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.memory / (2 ** 30)} GB)":\n{memory_result.trace}', + level=LogLevel.INFO, + ) - self.benchmark_traces() + # Here we have to recover the traces without the pass through remat in order to be compliant + # with thunder flow as we might have request for no remat + # Unpack dict + trc = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0] + self.bw_trace_candidates.attach_best_time_candidate(trc) + trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0] + self.bw_trace_candidates.attach_best_mem_candidate(trc) - def match_optimizer_algorithm(): - match self.optimization_algorithm: - case OptimizationAlgorithm.BEST_FUSER: - optmize_best_fuser() + # Now, finally build the pair fw and bw traces + # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller + for bw in self.bw_trace_candidates.iterable(): + self.out_traces_candidates.append(OutputCandidate(fw=self.active_fw_trace, bw=bw)) match self.trace_type: case TraceType.FW: - match_optimizer_algorithm() - # We have multiple cached optimized fw traces, find the best backward + fw_benchmark() case TraceType.BW: - # Cached the bw trace as we need to modify the input trace during the loop - cached_self_trace = from_trace(self.trace) - cached_self_trace.bound_symbols = list(self.trace.bound_symbols) - for label, candidate in self.cached_fw_traces.items(): - log(f"Backward optimization with fw from {label}", level=LogLevel.INFO) - fw_traces = candidate.iterable() - for trc in fw_traces: - # TODO (matteochen): unify below with the original block - - # Restore the original bw trace - self.trace = from_trace(cached_self_trace) - self.trace.bound_symbols = list(cached_self_trace.bound_symbols) - # Set the current active cached forward trace - self.active_fw_trace = trc - - log(f"Cached fw trace:\n{self.active_fw_trace}", level=LogLevel.DEBUG) - log(f"Input bw trace:\n{self.trace}", level=LogLevel.DEBUG) - - self.trace = update_bw_from_forward_optimization(fw=trc, bw=self.trace) - - if self.apply_bucketing_bw_trace: - from thunder.distributed.transforms import FSDPCommBucketing + bw_benchmark() - self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace(self.trace) + if self.produce_log: + import time + timestamp: str = str(time.time()) + with open(f"{timestamp}-{self.log_file_name}", "w") as file: + file.write(self.debug_msg) + file.close() - # Not called in the constructor for bw traces - dce(self.trace) + self.debug_msg = "" - match_optimizer_algorithm() - - from thunder.core.rematerialization import rematerialize_forward_and_backward - min_value_time: float = float("inf") - min_value_mem: float = float("inf") - best_pair_runtime: OutputCandidate - best_pair_memory: OutputCandidate - for pair in self.out_traces_candidates: - # Apply remat and select best trace pair - pair_cost_time = 0 - pair_cost_mem = 0 - remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) - t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) - log(f'Pair fw time: {t}, mem: {m}', level=LogLevel.INFO) - pair_cost_time = pair_cost_time + t - pair_cost_mem = pair_cost_mem + m - t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) - log(f'Pair bw time: {t}, mem: {m}', level=LogLevel.INFO) - pair_cost_time = pair_cost_time + t - pair_cost_mem = pair_cost_mem + m - - if pair_cost_time < min_value_time: - best_pair_runtime = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_time) - log(f"New best runtime pair:\n{best_pair_runtime}", level=LogLevel.INFO) - min_value_time = pair_cost_time - - if pair_cost_mem < min_value_mem: - best_pair_memory = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_mem) - log(f"New best memory pair:\n{best_pair_memory}", level=LogLevel.INFO) - min_value_mem = pair_cost_mem - - self.best_pair_runtime = best_pair_runtime - self.best_pair_memory = best_pair_memory - - def can_executor_execute(self, ex: Executor, bsym: BoundSymbol) -> bool: - try: - return ex.can_execute(bsym) - except Exception: - return False - - # For each fusion executor in the input list, find the best trace dispatching for each executor - def build_placement_options_best_fuser(self, increment_factor: int = 1): + def _search_candidates(self, increment_factor: int = 1): from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols from thunder.core.rematerialization import rematerialize_forward_and_backward - - def sequence_hash(s: Sequence) -> str: - name = "" - for e in s: - if ( - isinstance(e, CollectionProxy) - or isinstance(e, TensorProxy) - or isinstance(e, IntegerProxy) - or isinstance(e, FloatProxy) - ): - name += e.name - # TODO (matteochen): investigate if this is suitable - elif isinstance(e, int): - name += f"int{e}" - elif e is None: - name += "None" - else: - raise AssertionError(f"What? Maybe nested Sequence. type = {type(e)}") - return name - - # TODO (matteochen): Benchmark the optimal executor and call this optimal - def get_first_available_executor(bsym: BoundSymbol): - for ex in self.executors: - if isinstance(ex, FusionExecutor): - continue - if self.can_executor_execute(ex, bsym): - return ex - return Executor(name=self.empty_executor_hashable_placeholder) + from thunder.backend_optimizer.utils import ( + get_not_used_intermediate_outsputs, + sequence_hash, + can_executor_execute, + get_first_available_operator_executor, + assign_executors, + ) def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[BoundSymbol]): - # log(f'Input mapping len = {len(mapping)}:') - # log(f'Input bound_symbols len = {len(bound_symbols_in)}:') trc = from_trace(self.trace) trc.bound_symbols = list(bound_symbols_in) # For this partial trace we have to return all not used tensors otherwise the dce will cut them out - tensors = return_not_used_vars(trc) + tensors = get_not_used_intermediate_outsputs(trc) forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=tensors) @@ -568,16 +376,20 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo f"len trc.bound_symbols ({len(trc.bound_symbols)}) != len executor_configuration ({len(executor_configuration)}) != len keys ({len(keys)})" ) - # for b, e in zip(trc.bound_symbols, executor_configuration): - # if isinstance(b.output, TensorProxy): - # print(f'{b.sym.name}: {b.output.name} -> {e.name}') - - placed_trace = self.place_optimizers(trc, executor_configuration) + placed_trace = assign_executors( + in_trace=trc, + executor_list=executor_configuration, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + ) return placed_trace, keys, executor_configuration def search(ex: FusionExecutor): - # Each executor has a custom should fuse function, but the current impl need to access local executor object + """ + Fusable fn definition for nvFuser + """ + # Each executor has a custom should fuse function, but the current impl need to access local executor object def _should_fuse_nvfuser(a: Node, b: Node): def _can_fuse_node(n: Node): # if already merged, then node can be fused @@ -590,6 +402,10 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) + """ + Fusable fn definition for torch.compile + """ + def _should_fuse_torchcompile(a: Node, b: Node): def _can_fuse_node(n: Node): if len(n.group_bsyms) > 1: @@ -641,15 +457,14 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): if len(group) < 2: current_bsym = group[0] log(f"--> Single group: {current_bsym.sym.name}", level=LogLevel.DEBUG) - name = current_bsym.sym.name # Filter out all possible candidates for the current symbol candidate_executors = [ ex for ex in self.executors - if self.can_executor_execute(ex, current_bsym) and not isinstance(ex, FusionExecutor) + if can_executor_execute(ex, current_bsym) and not isinstance(ex, FusionExecutor) ] - if name == "return": + if current_bsym.sym.id == PrimIDs.RETURN: dict_time_strat["return"] = Executor(name=self.empty_executor_hashable_placeholder) dict_mem_strat["return"] = Executor(name=self.empty_executor_hashable_placeholder) # Add the modified return statement at the end of the for loop @@ -679,30 +494,28 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) t, m, _ = benchmark_trace(trc, self.benchmark_iters) # Update results - if t < candidate_best_time.tm: - candidate_best_time.tm = t - candidate_best_time.index = i - - if m < candidate_best_mem.mem: - candidate_best_mem.mem = m - candidate_best_mem.index = i + if t < candidate_best_time.runtime: + candidate_best_time = BenchmarkResult(time=t, index=i) + if m < candidate_best_mem.memory: + candidate_best_mem = BenchmarkResult(memory=m, index=i) if candidate_best_time.index == -1 or candidate_best_mem.index == -1: raise AssertionError( - f"Failed to get optimal single trace region candidate. Available candidates for {name}:\n{candidate_executors}" + f"Failed to get optimal single trace region candidate. Available candidates for {current_bsym.sym.name}:\n{candidate_executors}" ) log( - f"Best time OperatorExecutor for single {name}: {candidate_executors[candidate_best_time.index].name}", + f"Best time OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_time.index].name}", level=LogLevel.DEBUG, ) log( - f"Best mem OperatorExecutor for single {name}: {candidate_executors[candidate_best_mem.index].name}", + f"Best mem OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_mem.index].name}", level=LogLevel.DEBUG, ) match_bsym_output(current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) match_bsym_output(current_bsym, [dict_mem_strat], candidate_executors[candidate_best_mem.index]) + # Go to next bsym group continue # Inside groups we should have alwasy tensors as out @@ -725,20 +538,14 @@ def measure_and_update_result(): trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) if self.trace_type == TraceType.BW and self.active_fw_trace is not None: _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) - cost, mem, out = benchmark_trace(trc, self.benchmark_iters) - del out + cost, mem, _ = benchmark_trace(trc, self.benchmark_iters) log(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) - if cost < best_res_time.tm or (cost == best_res_time.tm and mem < best_res_time.mem): - best_res_time.tm = cost - best_res_time.mem = mem - best_res_time.trace = trc + if cost < best_res_time.runtime or (cost == best_res_time.runtime and mem < best_res_time.memory): + best_res_time = BenchmarkResult(time=cost, memory=mem, trace=trc) best_placement_time = placements best_keys_time = keys - - if mem < best_res_mem.mem or (mem == best_res_mem.mem and cost < best_res_mem.tm): - best_res_mem.tm = cost - best_res_mem.mem = mem - best_res_mem.trace = trc + if mem < best_res_mem.memory or (mem == best_res_mem.memory and cost < best_res_mem.runtime): + best_res_mem = BenchmarkResult(time=cost, memory=mem, trace=trc) best_placement_mem = placements best_keys_mem = keys @@ -773,7 +580,13 @@ def measure_and_update_result(): match_bsym_output(group[j], [dict_time_strat, dict_mem_strat], ex) for k in range(start_idx + i + 1, len(group), increment_factor): match_bsym_output( - group[k], [dict_time_strat, dict_mem_strat], get_first_available_executor(group[k]) + group[k], + [dict_time_strat, dict_mem_strat], + get_first_available_operator_executor( + bsym=group[k], + executors=self.executors, + empty_hash=self.empty_executor_hashable_placeholder, + ), ) # Benchmark measure_and_update_result() @@ -796,17 +609,14 @@ def measure_and_update_result(): raise AssertionError("Failed to get best placement") log( - f"For group {group_id} best placement with time cost = {best_res_time.tm} ms:\n{best_res_time.trace}", + f"For group {group_id} best placement with time cost = {best_res_time.runtime} ms:\n{best_res_time.trace}", level=LogLevel.DEBUG, ) log( - f"For group {group_id} best placement with mem cost = {best_res_mem.mem / (2**30)} GB:\n{best_res_mem.trace}", + f"For group {group_id} best placement with mem cost = {best_res_mem.memory / (2**30)} GB:\n{best_res_mem.trace}", level=LogLevel.DEBUG, ) - # for n, p in zip(best_keys, best_placement): - # print(f'{n} -> {p.name}') - # Update our dict for n, p in zip(best_keys_time, best_placement_time): dict_time_strat |= {n: p} @@ -818,7 +628,7 @@ def measure_and_update_result(): executors_time = [] executors_mem = [] for bsym in self.trace.bound_symbols: - if bsym.sym.name == "return": + if bsym.sym.id == PrimIDs.RETURN: if "return" not in dict_time_strat or "return" not in dict_mem_strat: raise AssertionError(f"Expected key return in mapping {dict_time_strat} and {dict_mem_strat}") executors_time.append(dict_time_strat["return"]) @@ -847,36 +657,32 @@ def measure_and_update_result(): raise AssertionError(f"Type not handled: {type(bsym.output)}") # For the forward trace we benchmark (memory) the mocked return statement as we don't know which - # Tensor will be returned after the rematerialize_forward_and_backward() call in order to do not overestimate the memory consumption + # tensor will be returned after the rematerialize_forward_and_backward() call in order to do not underestimate the memory consumption + trace = self.trace if self.trace_type == TraceType.FW: - trc = from_trace(self.trace) - trc.bound_symbols = list(self.trace.bound_symbols) - trc.bound_symbols.pop() - trc.bound_symbols.append(self.trace.bound_symbols[-1].from_bsym(args=return_not_used_vars(trc))) - # NOTE: Here the active trace to place will be 'trc' and not 'self.trace' - trc_time = self.place_optimizers(trc, executors_mem) - c, m, o = benchmark_trace(trc_time, self.benchmark_iters) - del o - log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc_time}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ex.name: trc_time}) - trc_mem = self.place_optimizers(trc, executors_time) - c, m, o = benchmark_trace(trc_mem, self.benchmark_iters) - del o - log(f"Debug TIME, time = {c} ms:\n{trc_mem}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ex.name: trc_mem}) - else: - trc = self.place_optimizers(self.trace, executors_mem) - _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) - c, m, o = benchmark_trace(trc, self.benchmark_iters) - del o - log(f"Debug MEM, mem = {m/(2**30)} GB:\n{trc}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_mem_benchmark_only.append({ex.name: trc}) - trc = self.place_optimizers(self.trace, executors_time) - _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) - c, m, o = benchmark_trace(trc, self.benchmark_iters) - del o - log(f"Debug TIME, time = {c} ms:\n{trc}", level=LogLevel.DEBUG) - self.fusion_strat_helper.optimized_traces_time_benchmark_only.append({ex.name: trc}) + trace = from_trace(self.trace) + trace.bound_symbols = list(self.trace.bound_symbols) + trace.bound_symbols.pop() + trace.bound_symbols.append( + self.trace.bound_symbols[-1].from_bsym(args=get_not_used_intermediate_outsputs(trace)) + ) + # Save the optimal traces that we have found + for executors, container in zip( + [executors_mem, executors_time], + [ + self.fusion_strat_helper.optimized_traces_mem_benchmark_only, + self.fusion_strat_helper.optimized_traces_time_benchmark_only, + ], + ): + trc = assign_executors( + in_trace=trace, + executor_list=executors, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + ) + if self.trace_type == TraceType.BW: + _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) + container.append({ex.name: trc}) # Save executors in order to generate real fw and bw trace with correct output with the placer self.executor_placement_options.placement_options_time.append(executors_time) @@ -896,14 +702,14 @@ def measure_and_update_result(): ex_compile_opts = self.known_fusion_ex_compile_options.get(ex.name, []) self.fusion_executors_saved_for_later.append(ex) + # Always search with option disabled -> standard flow search(ex) - # Always search with option disabled # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. # Consider implementing patters based on the executor under investingation if ex_compile_opts: for opt in ex_compile_opts: - op_in_trace: bool = self.op_in_trace(self.trace, opt.symbol_tag) + op_in_trace: bool = operation_in_trace(trace=self.trace, op=opt.symbol_tag) if op_in_trace: # Search with option enabled old_opt: bool | None = self.compile_data.compile_options.get(opt.fusion_tag, None) @@ -924,6 +730,10 @@ def measure_and_update_result(): search(ex) self.compile_data.compile_options[opt.fusion_tag] = old_opt + """ + ################################################## Public methods ################################################## + """ + def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: if not self.cached_fw_traces: raise AssertionError("Failed to obtain optimal fw traces") @@ -934,541 +744,171 @@ def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: ] def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: - return (self.best_pair_runtime.fw, self.best_pair_runtime.bw) if self.optimizer_type == OptimizerType.RUNTIME else (self.best_pair_memory.fw, self.best_pair_memory.bw) - # from thunder.core.rematerialization import rematerialize_forward_and_backward - - # # This is agnostic from the optimization strat as results are both floats - # min_value: float = float("inf") - # ans: OutputCandidate | None = None - # log(f"Computing the best pair option (tot options = {len(self.out_traces_candidates)})", level=LogLevel.INFO) - # for pair in self.out_traces_candidates: - - # # Apply remat and select best trace pair - # pair_cost = 0 - # remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) - # t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) - # log(f'Pair fw time: {t}, mem: {m}', level=LogLevel.DEBUG) - # pair_cost = pair_cost + t if self.optimizer_type == OptimizerType.RUNTIME else m - # t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) - # log(f'Pair bw time: {t}, mem: {m}', level=LogLevel.DEBUG) - # pair_cost = pair_cost + t if self.optimizer_type == OptimizerType.RUNTIME else m - # pair.fw = remat_fw - # pair.bw = remat_bw - # pair.tot_cost = pair_cost - - # if pair.tot_cost < min_value: - # log(f"New best pair:\n{pair}", level=LogLevel.DEBUG) - # min_value = pair.tot_cost - # ans = pair - # if ans is None: - # raise AssertionError("Best pair not found") - - # fw = ans.fw - # c, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) - # log(f"Final candidate pair fw: {c} ms - {m / (2**30)} GB\n{fw}", level=LogLevel.INFO) - # bw = ans.bw - # c, m, _ = benchmark_trace(bw, iters=self.benchmark_iters) - # log(f"Final candidate pair bw: {c} ms - {m / (2**30)} GB\n{bw}", level=LogLevel.INFO) - - # # To debug this: the traces that we will received in the remat call in should be the same as these and runtime should be in line with the best pair time. - # # The pairs above are traces with no remat call (in order to be called later on) but their tracking time are made with traces gone under the remat call - # return ans.fw, ans.bw - - def bsym_assigned(self, bsym: BoundSymbol) -> bool: - return isinstance(bsym.sym.executor, OperatorExecutor) or isinstance(bsym.sym.executor, FusionExecutor) - - def benchmark_traces(self): - self.debug_msg += "Traces benchmarks:\n\n" - - # We cache every optimized fw traces as they might impact differently on the bw trace - # Number of fw traces to cached are: #fusion_executors * 2 - def fw_benchmark(): - match self.optimization_algorithm: - case OptimizationAlgorithm.BEST_FUSER: - # The optimizator builds the results in order following the self.fusion_executors list order - for pair_time, pair_mem in zip( - self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem - ): - # pair is a dict - trc_time = list(pair_time.values())[0] - trc_mem = list(pair_mem.values())[0] - label = list(pair_time.keys())[0] - # TODO (matteochen): remove the benchmark here as will done later on the bw pass - c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) - log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', - level=LogLevel.INFO, - ) - self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" - c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) - log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', - level=LogLevel.INFO, - ) - self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" + return ( + (self.best_pair_runtime.fw, self.best_pair_runtime.bw) + if self.optimizer_type == OptimizerType.RUNTIME + else (self.best_pair_memory.fw, self.best_pair_memory.bw) + ) - self.cached_fw_traces[label] = TraceCandidates(best_time=trc_time, best_mem=trc_mem) + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + from thunder.core.transform_common import dce - def bw_benchmark(): - time_result = BenchmarkResult() - memory_result = BenchmarkResult() + self.trace_type = trace_type + # dce for the backward trace will be passed afterwards + self.trace: TraceCtx = dce(trace) if trace_type == TraceType.FW else trace - # Find best trace for runtime - for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_time_benchmark_only): - # Unpack the dict - label = list(pair.keys())[0] - trace = list(pair.values())[0] - trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) - self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" + match self.trace_type: + case TraceType.FW: log( - f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', - level=LogLevel.INFO, + f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO ) - if trace_time < time_result.tm: - time_result.tm = trace_time - time_result.mem = trace_mem - time_result.trace = trace - time_result.label = label - time_result.index = i - - # Find best trace for memory - for i, pair in enumerate(self.fusion_strat_helper.optimized_traces_mem_benchmark_only): - # Unpack the dict - label = list(pair.keys())[0] - trace = list(pair.values())[0] - - trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) - del res - self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" + # TODO (matteochen): support bw trace optimization even though with no fw traces cached + case TraceType.BW: + if not self.cached_fw_traces: + raise AssertionError("Can not optimize backward traces before forward traces") log( - f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', + f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO, ) - if trace_mem < memory_result.mem: - memory_result.tm = trace_time - memory_result.mem = trace_mem - memory_result.trace = trace - memory_result.label = label - memory_result.index = i - - log( - f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.tm} ms)":\n{time_result.trace}', - level=LogLevel.INFO, - ) - log( - f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.mem / (2 ** 30)} GB)":\n{memory_result.trace}', - level=LogLevel.INFO, - ) - # Here we have to recover the traces without the pass through remat in order to be compliant - # with thunder flow as we might have request for no remat - match self.optimization_algorithm: - case OptimizationAlgorithm.BEST_FUSER: - # Unpack dict - trc = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0] - self.bw_trace_candidates.attach_best_time_candidate(trc) + def optimize(self): + from thunder.core.transform_common import dce + from thunder.executors.torch_autograd import update_bw_from_forward_optimization + from thunder.backend_optimizer.utils import assign_executors - # Unpack dict - trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0] - self.bw_trace_candidates.attach_best_mem_candidate(trc) + def _optimize(): + # Reset fusion helpers + self.fusion_strat_helper = FusionStratHelper() + # Reset helpers data structures + self.executor_placement_options = ExecutorPlacementOptions() - log(self.bw_trace_candidates.__repr__(), level=LogLevel.DEBUG) + self._search_candidates() - # Now, finally build the pair fw and bw traces - # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller + if len(self.executor_placement_options.placement_options_time) != len( + self.fusion_executors_saved_for_later + ): + raise AssertionError( + f"Unexpected time placement options size: {len(self.executor_placement_options.placement_options_time)}. Expected: {len(self.fusion_executors)}" + ) + if len(self.executor_placement_options.placement_options_mem) != len(self.fusion_executors_saved_for_later): + raise AssertionError( + f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors)}" + ) - # forward_time, forward_memory, _ = benchmark_trace(self.active_fw_trace, self.benchmark_iters) + for placement, ex in zip( + self.executor_placement_options.placement_options_time, self.fusion_executors_saved_for_later + ): + self.fusion_strat_helper.optimized_traces_time.append( + { + ex.name: assign_executors( + in_trace=self.trace, + executor_list=placement, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + ) + } + ) + for placement, ex in zip( + self.executor_placement_options.placement_options_mem, self.fusion_executors_saved_for_later + ): + self.fusion_strat_helper.optimized_traces_mem.append( + { + ex.name: assign_executors( + in_trace=self.trace, + executor_list=placement, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + ) + } + ) + # Filter out the optimal candidates for the current serach iteration + self._filter_candidates() - # from thunder.core.rematerialization import rematerialize_forward_and_backward + match self.trace_type: + case TraceType.FW: + _optimize() + # We have multiple cached optimized fw traces, find the best backward + case TraceType.BW: + # Clear any previous results + self.out_traces_candidates = [] - # min_value_time: float = float("inf") - # min_value_mem: float = float("inf") - # best_pair_runtime: OutputCandidate - # best_pair_memory: OutputCandidate - for bw in self.bw_trace_candidates.iterable(): + # Cached the bw trace as we need to modify the input trace during the loop + cached_self_trace = from_trace(self.trace) + cached_self_trace.bound_symbols = list(self.trace.bound_symbols) + for label, candidate in self.cached_fw_traces.items(): + log(f"Backward optimization with fw from {label}", level=LogLevel.INFO) + fw_traces = candidate.iterable() + for trc in fw_traces: + # TODO (matteochen): unify below with the original block - self.out_traces_candidates.append(OutputCandidate(fw=self.active_fw_trace, bw=bw)) + # Restore the original bw trace + self.trace = from_trace(cached_self_trace) + self.trace.bound_symbols = list(cached_self_trace.bound_symbols) + # Set the current active cached forward trace + self.active_fw_trace = trc - # Apply remat and select best trace pair - # pair_cost_time = 0 - # pair_cost_mem = 0 - # remat_fw, remat_bw = rematerialize_forward_and_backward(self.active_fw_trace, bw) - # t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) - # log(f'Pair fw time: {t}, mem: {m}', level=LogLevel.INFO) - # pair_cost_time = pair_cost_time + t - # pair_cost_mem = pair_cost_mem + m - # t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) - # log(f'Pair bw time: {t}, mem: {m}', level=LogLevel.INFO) - # pair_cost_time = pair_cost_time + t - # pair_cost_mem = pair_cost_mem + m - - # if pair_cost_time < min_value_time: - # best_pair_runtime = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_time) - # log(f"New best runtime pair:\n{best_pair_runtime}", level=LogLevel.INFO) - # min_value_time = pair_cost_time - - # if pair_cost_mem < min_value_mem: - # best_pair_memory = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_mem) - # log(f"New best memory pair:\n{best_pair_memory}", level=LogLevel.INFO) - # min_value_mem = pair_cost_mem - - # self.best_pair_runtime = best_pair_runtime - # self.best_pair_memory = best_pair_memory - - # log( - # f"Output pair candidate for TIME strat from best_time res: (fw){forward_time} ms, (bw){time_result.tm} ms", - # level=LogLevel.INFO, - # ) - # self.out_traces_candidates.append( - # OutputCandidate( - # fw=self.active_fw_trace, - # bw=self.bw_trace_candidates.best_time, - # cost=forward_time + time_result.tm, - # ) - # ) - # log( - # f"Output pair candidate for TIME strat from best_mem res: (fw){forward_time} ms, (bw){memory_result.tm} ms", - # level=LogLevel.INFO, - # ) - # self.out_traces_candidates.append( - # OutputCandidate( - # fw=self.active_fw_trace, - # bw=self.bw_trace_candidates.best_mem, - # cost=forward_time + memory_result.tm, - # ) - # ) - # log( - # f"Output pair candidate for MEM strat from best_time res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){time_result.mem / (2**30)} GB (bw){time_result.tm} ms", - # level=LogLevel.INFO, - # ) - # self.out_traces_candidates.append( - # OutputCandidate( - # fw=self.active_fw_trace, - # bw=self.bw_trace_candidates.best_time, - # cost=forward_memory + time_result.mem, - # ) - # ) - # log( - # f"Output pair candidate for MEM from strat best_mem res: (fw){forward_memory / (2**30)} GB (fw){forward_time} ms, (bw){memory_result.mem / (2**30)} GB (bw){memory_result.tm} ms", - # level=LogLevel.INFO, - # ) - # self.out_traces_candidates.append( - # OutputCandidate( - # fw=self.active_fw_trace, - # bw=self.bw_trace_candidates.best_mem, - # cost=forward_memory + memory_result.mem, - # ) - # ) + log(f"Cached fw trace:\n{self.active_fw_trace}", level=LogLevel.INFO) + log(f"Input bw trace:\n{self.trace}", level=LogLevel.DEBUG) - match self.trace_type: - case TraceType.FW: - fw_benchmark() - case TraceType.BW: - bw_benchmark() + self.trace = update_bw_from_forward_optimization(fw=trc, bw=self.trace) - if self.produce_log: - import time + if self.apply_bucketing_bw_trace: + from thunder.distributed.transforms import FSDPCommBucketing - timestamp: str = str(time.time()) - with open(f"{timestamp}-{self.log_file_name}", "w") as file: - file.write(self.debug_msg) - file.close() + self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace(self.trace) - self.debug_msg = "" + # Not called in the constructor for bw traces + dce(self.trace) + _optimize() -def return_not_used_vars(trace_in: TraceCtx) -> list[TensorProxy]: - def is_in_sequence(seq: Sequence[Any], t: TensorProxy): - for e in seq: - if isinstance(e, TensorProxy) and e.name == t.name: - return True - return False - - # Check if this naming is always valid - def is_possible_out(name: str): - if not name.startswith("t"): - return False - num = name[1:] - return num.isdigit() - - ans: list[TensorProxy] = [] - for b in trace_in.bound_symbols: - f = False - # Not a tensor - if not isinstance(b.output, TensorProxy): - continue - # Not a produced tensor - if not is_possible_out(b.output.name): - continue - for test in trace_in.bound_symbols: - if ( - test.args is not None - and (isinstance(test.args, tuple) or isinstance(test.args, list)) - and is_in_sequence(test.args, b.output) - ): - f = True - break - if not f: - ans.append(b.output) - return ans - - -# TODO (matteochen): move into utils module -def benchmark_trace( - trace: TraceCtx, - iters: int = 1, - show_func=False, - apply_del_last_used=True, - snapshot=False, - snapshot_name="", - nvsight: bool = False, - nvsight_fn_name: str = "", -) -> tuple[float, float, Any]: - from thunder.executors.passes import del_last_used - import inspect - - input_args = [] - - if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: - raise AssertionError("Missing return statement") - - def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: - try: - warm_up_iters = 50 - torch.cuda.empty_cache() - # Warm up cycles - for _ in range(warm_up_iters): - fn(*args) - # Benchmark - torch.cuda.cudart().cudaProfilerStart() - for i in range(iters): - torch.cuda.nvtx.range_push(f"{nvsight_fn_name}-iter{i}") - fn(*args) - torch.cuda.nvtx.range_pop() - torch.cuda.cudart().cudaProfilerStop() - - return float("inf"), float("inf"), None - except Exception as e: - import inspect - - trc = inspect.getsource(fn) - print(f"#NVSIGHT FN EXECUTION FAILED:\n{trc}") - raise e - - def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: - try: - warm_up_iters = 50 - out = None - torch.cuda.empty_cache() - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - - # Warm up cycles - for _ in range(warm_up_iters): - fn(*args) - # Snapshot request - if snapshot: - torch.cuda.memory._record_memory_history() - fn(*args) - torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") - torch.cuda.memory._record_memory_history(enabled=None) - # Benchmark - stream = torch.cuda.current_stream() - max_allocated_bytes = 0 - torch.cuda.synchronize() - for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - fn(*args) - end_events[i].record(stream) - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) + self.best_pair_runtime, self.best_pair_memory = self._best_runtime_and_memory_candidates( + self.out_traces_candidates ) - torch.cuda.synchronize() - times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - print(f"times: {times}") - tot_time = sum(times) / iters - return tot_time, max_allocated_bytes, out - except Exception as e: - print(f"#FN EXECUTION FAILED:\n{repr}") - raise e - - def print_input_args(args, level=0, show_content=False): - for e in args: - if isinstance(e, tuple) or isinstance(e, list): - print_input_args(e, level=level + 1) - else: - print(f"level {level}", type(e)) - - # def print_trace_execution_output(out: Any, show_content=False): - # if isinstance(out, tuple): - # for e in out: - # print(f'{type(e)}') - # else: - # print(f'{type(out)}') - - # TODO (matteochen): convert this into dict - def thunder_to_torch_float_dtype(tp: dtype, byte: int) -> torch.dtype: - if byte == 1: - raise AssertionError("Not implmented: 8 bit float") - # Dispatch flaot 16 type 1 from type 2 - elif byte == 2: - if tp._name == thunder.bfloat16._name: - return torch.bfloat16 - else: - return torch.float16 - elif byte == 4: - return torch.float32 - elif byte == 8: - return torch.float64 - else: - raise AssertionError(f"Not supported byte = {byte}") - - # TODO (matteochen): convert this into dict - def thunder_to_torch_int_dtype(byte: int) -> torch.dtype: - if byte == 1: - return torch.int8 - elif byte == 2: - return torch.int16 - elif byte == 4: - return torch.int32 - elif byte == 8: - return torch.int64 - else: - raise AssertionError(f"Not supported byte = {byte}") - - # TODO (matteochen): use more appropriate mock int and float - def transform_input_tuple(t: tuple, level=0) -> tuple: - res = [] - for e in t: - if type(e) is tuple: - res.append(transform_input_tuple(e, level + 1)) - else: - if isinstance(e, TensorProxy): - res.append(transform_tensor(e)) - elif isinstance(e, IntegerProxy): - if e.python_type is bool: - res.append(False if e.value is None else e.value) - else: - res.append(0 if e.value is None else e.value) - elif isinstance(e, FloatProxy): - res.append(0.0 if e.value is None else e.value) - elif e is None: - res.append(None) - else: - raise AssertionError(f"Input arg type not recognized: {type(e)}") - return tuple(res) - - def transform_tensor(arg: TensorProxy) -> torch.Tensor: - from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype - - # TODO (matteochen): Missing parallel and fsdp handling... - # TODO (matteochen): Missing support for meta types ... - dtype = arg.dtype - shape = arg.shape - device = arg.device - requires_grad = arg.requires_grad - if dtype is not None and is_float_dtype(dtype): - torch_dtype = thunder_to_torch_float_dtype(dtype, dtype.bytes) - tensor: torch.Tensor = torch.randn( - shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad - ) - elif dtype is not None and is_signedinteger_dtype(dtype): - torch_dtype = thunder_to_torch_int_dtype(dtype.bytes) - tensor: torch.Tensor = torch.randint( - 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad - ) - elif dtype is not None and is_boolean_dtype(dtype): - # TODO (matteochen): maybe random? - tensor: torch.Tensor = torch.zeros( - *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad + +class BackendOptimizer: + def __init__( + self, + *, + priority_executors: Sequence[Executor], + produce_log=True, + apply_bucketing_bw_trace: bool, + log_file_name="autotune_debug.log", + visualizer: Visualizer | None = None, + optimizer_type: OptimizerType = OptimizerType.RUNTIME, + optimizer_algorithm: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER, + compile_data, + ) -> None: + self.optimizer = ( + FusionPlacer( + priority_executors=priority_executors, + produce_log=produce_log, + apply_bucketing_bw_trace=apply_bucketing_bw_trace, + log_file_name=log_file_name, + visualizer=visualizer, + optimizer_type=optimizer_type, + compile_data=compile_data, ) - else: - raise AssertionError(f"dtype {dtype} not supported yet") - - return tensor - - # print(f'BENCHMARKING:\n{trace}') - # def p(args): - # for e in args: - # if not isinstance(e, Sequence): - # if isinstance(e, torch.Tensor): - # print(f'{e.size()}') - # else: - # try: - # print(f'{e.name} -> {e}') - # except: - # print(f'{e}') - # else: - # print('rec') - # p(e) - # p(trace.args) - # print('##################') - # p(input_args) - - # Can we remove this check? - # TODO (matteochen): use more appropriate mock int and float - if isinstance(trace.args, Sequence): - for arg in trace.args: - if isinstance(arg, tuple): - input_args.append(transform_input_tuple(arg)) - elif isinstance(arg, TensorProxy): - e = transform_tensor(arg) - input_args.append(e) - elif isinstance(arg, IntegerProxy): - if arg.python_type is bool: - input_args.append(False if arg.value is None else arg.value) - else: - input_args.append(0 if arg.value is None else arg.value) - elif isinstance(arg, FloatProxy): - input_args.append(0.0 if arg.value is None else arg.value) - else: - raise AssertionError(f"Input arg type not recognized: {type(arg)}") - else: - raise AssertionError("Unexpexcted args type") - - if apply_del_last_used: - trace = del_last_used(trace) - - trace_tok = set_tracectx(trace) - - # Obtain the python executable string - execuhtable_str = trace.python() - executable = trace.python_callable() - if show_func: - print(inspect.getsource(executable)) - - t = float("inf") - m = float("inf") - answer = None - try: - if nvsight: - t, m, answer = compute_time_cost_nvsight(executable, iters, *input_args) - else: - t, m, answer = compute_time_cost_ms(executable, execuhtable_str, iters, *input_args) - except Exception as e: - # https://github.com/Lightning-AI/lightning-thunder/issues/664 - # Seems that this patch never work ... - print(f"Exception:\n{e}") - if ( - "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) - and not nvsight - ): - print( - "Executing with torch compile no full graph (this might still fail), see: https://github.com/Lightning-AI/lightning-thunder/issues/664" + if optimizer_algorithm == OptimizationAlgorithm.BEST_FUSER + else None + ) + + log("Executors:", level=LogLevel.INFO) + for e in priority_executors: + log( + f"{e.name} -> is operator = {isinstance(e, OperatorExecutor)}, is fusion = {isinstance(e, FusionExecutor)}", + level=LogLevel.INFO, ) - torch_compiled = torch.compile(executable, fullgraph=False) - try: - t, m, answer = compute_time_cost_ms(torch_compiled, iters, *input_args) - except Exception as e: - print(f"Compiled trace execution still failed:\n{e}") - else: - print(f"Unknown exception occured:\n{e}") - finally: - reset_tracectx(trace_tok) - - return t, m, answer + + def optimize(self): + self.optimizer.optimize() + + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + self.optimizer.attach_trace(trace=trace, trace_type=trace_type) + + def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + return self.optimizer.get_optimal_fw_traces() + + def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + return self.optimizer.get_optimal_fw_bw_traces() diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py new file mode 100644 index 0000000000..f05ed6634c --- /dev/null +++ b/thunder/backend_optimizer/utils.py @@ -0,0 +1,503 @@ +from collections.abc import Callable, Hashable, Sequence +from typing import Any +from thunder.core.dtypes import dtype +from thunder.core.prims import PrimIDs +from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify +from thunder.core.symbol import BoundSymbol, Symbol +from thunder.core.trace import TraceCtx, get_tracectx, reset_tracectx, set_tracectx +from thunder.extend import Executor, FusionExecutor, OperatorExecutor +from thunder.core.utils import check, safe_map_flat +import thunder.core.transforms as transforms +from itertools import chain +import torch +import thunder + +def sequence_hash(s: Sequence) -> str: + name = "" + for e in s: + if ( + isinstance(e, CollectionProxy) + or isinstance(e, TensorProxy) + or isinstance(e, IntegerProxy) + or isinstance(e, FloatProxy) + ): + name += e.name + # TODO (matteochen): investigate if this is suitable + elif isinstance(e, int): + name += f"int{e}" + elif e is None: + name += "None" + else: + raise AssertionError(f"What? Maybe nested Sequence. type = {type(e)}") + return name + +def can_executor_execute(ex: Executor, bsym: BoundSymbol) -> bool: + try: + return ex.can_execute(bsym) + except Exception: + return False + +def get_first_available_operator_executor(*, bsym: BoundSymbol, executors: Sequence[Executor], empty_hash: str = 'empty'): + for ex in executors: + if isinstance(ex, FusionExecutor): + continue + if can_executor_execute(ex, bsym): + return ex + return Executor(name=empty_hash) + + +def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[TensorProxy]: + def is_in_sequence(seq: Sequence[Any], t: TensorProxy): + for e in seq: + if isinstance(e, TensorProxy) and e.name == t.name: + return True + return False + + def is_possible_out(name: str): + if not name.startswith("t"): + return False + num = name[1:] + return num.isdigit() + + ans: list[TensorProxy] = [] + for b in trace_in.bound_symbols: + f = False + # Not a tensor + if not isinstance(b.output, TensorProxy): + continue + # Not a produced tensor + if not is_possible_out(b.output.name): + continue + for test in trace_in.bound_symbols: + if ( + test.args is not None + and (isinstance(test.args, tuple) or isinstance(test.args, list)) + and is_in_sequence(test.args, b.output) + ): + f = True + break + if not f: + ans.append(b.output) + return ans + +def assign_executors( + *, + in_trace: TraceCtx, + executor_list: list[Executor] | tuple[Executor, ...], + always_executors: list[Executor] | tuple[Executor, ...], + empty_str: str | Hashable, +) -> TraceCtx: + from thunder.executors.passes import _transform_for_operator_executor_execution + + swapmap: dict[Variable, Proxy] = {} + + def restore_correct_args(trace_in: TraceCtx): + def args_eq(a, b) -> bool: + if len(a) != len(b): + return False + for obj_a, obj_b in zip(a, b): + if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): + if obj_a.name != obj_b.name: + return False + elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): + if obj_a != obj_b: + raise AssertionError(f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") + return True + + def clear(bsym: BoundSymbol, input): + size = len(bsym.subsymbols) + if size > 0: + for subsym in bsym.subsymbols: + if not args_eq(subsym.args, input): + subsym.args = tuple(list(input)) + clear(subsym, input) + + for bsym in trace_in.bound_symbols: + if isinstance(bsym.sym.executor, OperatorExecutor): + clear(bsym, bsym.args) + + def update_swapmap(o: Any, no: Any) -> None: + if isinstance(o, Proxy): + check( + isinstance(no, Proxy), + lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}", + ) + + vo = variableify(o) + vno = variableify(no) + if vo == vno: + return + swapmap[vno] = o + + def preserve_bsym(bsym: BoundSymbol) -> Any: + trace: TraceCtx | None = get_tracectx() + if trace is None: + raise AssertionError("None trace context") + trace.scopes[-1].append(bsym) + for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): + trace.names.add(p.name) + return bsym.output + + def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: + if bsym.sym.python_impl is not None: + return None + + # We have mapped this at previous stages + if ex.name == empty_str: + return None + + execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) + out: Any + if execution_transform is not None: + out = execution_transform(*bsym.args, **bsym.kwargs) + elif isinstance(ex, OperatorExecutor): + # Calls the operator executor's operation + op: Symbol | None = ex.implmap[bsym.sym.id].symbol + if op is None: + raise AssertionError("op is None") + out = op(*bsym.args, **bsym.kwargs) + elif isinstance(ex, FusionExecutor): + # Preserves the symbol as is (it will be handled in the fusion pass) + out = preserve_bsym(bsym) + else: + raise AssertionError("Unknown executor") + + safe_map_flat(update_swapmap, bsym.output, out) + + return True + + def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: + return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE + + if len(executor_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executor_list) != len(in_trace.bound_symbols)") + + cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} + executor_mapping: dict[str, Executor] = {} + unique_fusion_executors = set() + + # Input should have equal length + if len(executor_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") + + for b, e in zip(in_trace.bound_symbols, executor_list): + if isinstance(e, FusionExecutor): + unique_fusion_executors.add(e) + if isinstance(b.output, TensorProxy): + executor_mapping[b.output.name] = e + + extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) + + # Restores original variables + bound_symbols: list[BoundSymbol] = [] + for bsym in extrace.bound_symbols: + nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) + bound_symbols.append(nbsym) + extrace.bound_symbols = bound_symbols + + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy): + t_name = bsym.output.name + if t_name not in executor_mapping: + # Symbol added by the visitor + continue + # raise AssertionError('Failed to retrive key in mapping') + saved_ex = executor_mapping[t_name] + if isinstance(saved_ex, OperatorExecutor): + cached_subsymbols[t_name] = list(bsym.subsymbols) + # This will leave out these symbols from the fusion pass + bsym.subsymbols = [] + + # Perform fusion pass + for ex in unique_fusion_executors: + extrace = ex.fusion_pass(extrace) + + # Restore subsymbols + # TODO (matteochen): Improve this search + for k, v in cached_subsymbols.items(): + # Note some symbols may be cut out by the fusion pass -> CSE + # For example: + # a = 1 + 1 + # b = 1 + 1 + # c = a + b + # being replaced by c = a + a + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy) and bsym.output.name == k: + bsym.subsymbols = v + + restore_correct_args(extrace) + + # Apply always executors + extrace = _transform_for_operator_executor_execution(extrace, always_executors) + + return extrace + +def operation_in_trace(*, trace: TraceCtx, op: str) -> bool: + # Some optimizations are not available as symbols + always_true = set(["bookend"]) + + if op in always_true: + return True + for b in trace.bound_symbols: + if b.sym.name == op: + return True + return False + +def benchmark_trace( + trace: TraceCtx, + iters: int = 1, + show_func=False, + apply_del_last_used=True, + snapshot=False, + snapshot_name="", + nvsight: bool = False, + nvsight_fn_name: str = "", +) -> tuple[float, float, Any]: + from thunder.executors.passes import del_last_used + import inspect + + input_args = [] + + if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: + raise AssertionError("Missing return statement") + + def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: + try: + warm_up_iters = 50 + torch.cuda.empty_cache() + # Warm up cycles + for _ in range(warm_up_iters): + fn(*args) + # Benchmark + torch.cuda.cudart().cudaProfilerStart() + for i in range(iters): + torch.cuda.nvtx.range_push(f"{nvsight_fn_name}-iter{i}") + fn(*args) + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + + return float("inf"), float("inf"), None + except Exception as e: + import inspect + + trc = inspect.getsource(fn) + print(f"#NVSIGHT FN EXECUTION FAILED:\n{trc}") + raise e + + def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: + try: + warm_up_iters = 50 + out = None + torch.cuda.empty_cache() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + # Warm up cycles + for _ in range(warm_up_iters): + fn(*args) + # Snapshot request + if snapshot: + torch.cuda.memory._record_memory_history() + fn(*args) + torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") + torch.cuda.memory._record_memory_history(enabled=None) + # Benchmark + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + torch.cuda.synchronize() + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + fn(*args) + end_events[i].record(stream) + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + print(f"times: {times}") + tot_time = sum(times) / iters + return tot_time, max_allocated_bytes, out + except Exception as e: + print(f"#FN EXECUTION FAILED:\n{repr}") + raise e + + def print_input_args(args, level=0, show_content=False): + for e in args: + if isinstance(e, tuple) or isinstance(e, list): + print_input_args(e, level=level + 1) + else: + print(f"level {level}", type(e)) + + # def print_trace_execution_output(out: Any, show_content=False): + # if isinstance(out, tuple): + # for e in out: + # print(f'{type(e)}') + # else: + # print(f'{type(out)}') + + # TODO (matteochen): convert this into dict + def thunder_to_torch_float_dtype(tp: dtype, byte: int) -> torch.dtype: + if byte == 1: + raise AssertionError("Not implmented: 8 bit float") + # Dispatch flaot 16 type 1 from type 2 + elif byte == 2: + if tp._name == thunder.bfloat16._name: + return torch.bfloat16 + else: + return torch.float16 + elif byte == 4: + return torch.float32 + elif byte == 8: + return torch.float64 + else: + raise AssertionError(f"Not supported byte = {byte}") + + # TODO (matteochen): convert this into dict + def thunder_to_torch_int_dtype(byte: int) -> torch.dtype: + if byte == 1: + return torch.int8 + elif byte == 2: + return torch.int16 + elif byte == 4: + return torch.int32 + elif byte == 8: + return torch.int64 + else: + raise AssertionError(f"Not supported byte = {byte}") + + # TODO (matteochen): use more appropriate mock int and float + def transform_input_tuple(t: tuple, level=0) -> tuple: + res = [] + for e in t: + if type(e) is tuple: + res.append(transform_input_tuple(e, level + 1)) + else: + if isinstance(e, TensorProxy): + res.append(transform_tensor(e)) + elif isinstance(e, IntegerProxy): + if e.python_type is bool: + res.append(False if e.value is None else e.value) + else: + res.append(0 if e.value is None else e.value) + elif isinstance(e, FloatProxy): + res.append(0.0 if e.value is None else e.value) + elif e is None: + res.append(None) + else: + raise AssertionError(f"Input arg type not recognized: {type(e)}") + return tuple(res) + + def transform_tensor(arg: TensorProxy) -> torch.Tensor: + from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype, is_boolean_dtype + + # TODO (matteochen): Missing parallel and fsdp handling... + # TODO (matteochen): Missing support for meta types ... + dtype = arg.dtype + shape = arg.shape + device = arg.device + requires_grad = arg.requires_grad + if dtype is not None and is_float_dtype(dtype): + torch_dtype = thunder_to_torch_float_dtype(dtype, dtype.bytes) + tensor: torch.Tensor = torch.randn( + shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif dtype is not None and is_signedinteger_dtype(dtype): + torch_dtype = thunder_to_torch_int_dtype(dtype.bytes) + tensor: torch.Tensor = torch.randint( + 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif dtype is not None and is_boolean_dtype(dtype): + # TODO (matteochen): maybe random? + tensor: torch.Tensor = torch.zeros( + *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad + ) + else: + raise AssertionError(f"dtype {dtype} not supported yet") + + return tensor + + # print(f'BENCHMARKING:\n{trace}') + # def p(args): + # for e in args: + # if not isinstance(e, Sequence): + # if isinstance(e, torch.Tensor): + # print(f'{e.size()}') + # else: + # try: + # print(f'{e.name} -> {e}') + # except: + # print(f'{e}') + # else: + # print('rec') + # p(e) + # p(trace.args) + # print('##################') + # p(input_args) + + # Can we remove this check? + # TODO (matteochen): use more appropriate mock int and float + if isinstance(trace.args, Sequence): + for arg in trace.args: + if isinstance(arg, tuple): + input_args.append(transform_input_tuple(arg)) + elif isinstance(arg, TensorProxy): + e = transform_tensor(arg) + input_args.append(e) + elif isinstance(arg, IntegerProxy): + if arg.python_type is bool: + input_args.append(False if arg.value is None else arg.value) + else: + input_args.append(0 if arg.value is None else arg.value) + elif isinstance(arg, FloatProxy): + input_args.append(0.0 if arg.value is None else arg.value) + else: + raise AssertionError(f"Input arg type not recognized: {type(arg)}") + else: + raise AssertionError("Unexpexcted args type") + + if apply_del_last_used: + trace = del_last_used(trace) + + trace_tok = set_tracectx(trace) + + # Obtain the python executable string + executable_str = trace.python() + executable = trace.python_callable() + if show_func: + print(inspect.getsource(executable)) + + t = float("inf") + m = float("inf") + answer = None + try: + if nvsight: + t, m, answer = compute_time_cost_nvsight(executable, iters, *input_args) + else: + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + except Exception as e: + # https://github.com/Lightning-AI/lightning-thunder/issues/664 + # Seems that this patch never work ... + print(f"Exception:\n{e}") + if ( + "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) + and not nvsight + ): + print( + "Executing with torch compile no full graph (this might still fail), see: https://github.com/Lightning-AI/lightning-thunder/issues/664" + ) + torch_compiled = torch.compile(executable, fullgraph=False) + try: + t, m, answer = compute_time_cost_ms(torch_compiled, executable_str, iters, *input_args) + except Exception as e: + print(f"Compiled trace execution still failed:\n{e}") + else: + print(f"Unknown exception occured:\n{e}") + finally: + reset_tracectx(trace_tok) + + return t, m, answer diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 8ea9a90f6a..54d0806e11 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -1,5 +1,5 @@ import torch -from thunder.backend_optimizer.optimizer import benchmark_trace +from thunder.backend_optimizer.utils import benchmark_trace warm_up_iters = 50 From 80a11e66bbead1267dabf8de690f7a232bae96e5 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 7 Aug 2024 18:44:39 +0300 Subject: [PATCH 030/171] Enlarged benchmark iters and use no remat traces a chance to be the optimal (#11) --- examples/dev/litGPT.py | 6 +++--- thunder/backend_optimizer/optimizer.py | 24 ++++++++++++++++++++++- thunder/backend_optimizer/utils.py | 27 +++++++++++++------------- 3 files changed, 40 insertions(+), 17 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 0668921d5a..56f3430839 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -9,9 +9,9 @@ def __init__(self, layers: int, autotune_type: str) -> None: self.layers = layers self.autotune_type = autotune_type -layers = [Test(8, 'runtime'), Test(8, 'runtime'), Test(16, 'runtime')] +layers = [Test(8, 'runtime')] -model_name = 'Llama-2-7b-hf' +model_name = 'Llama-3-8B' for test in layers: try: @@ -25,7 +25,7 @@ def __init__(self, layers: int, autotune_type: str) -> None: jmodel_def = thunder.jit(model) # Torchcompile gives some troubles for now - jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'torchcompile', 'cudnn', 'torch', 'python']) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 98dd8319db..8744b5477a 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -164,7 +164,7 @@ def __init__( self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace - self.benchmark_iters: int = 5 + self.benchmark_iters: int = 20 self.compile_data = compile_data @@ -189,6 +189,28 @@ def _best_runtime_and_memory_candidates(self, candidates): best_pair_runtime: OutputCandidate best_pair_memory: OutputCandidate for pair in candidates: + # No remat + pair_cost_time = 0 + pair_cost_mem = 0 + t, m, _ = benchmark_trace(pair.fw, iters=self.benchmark_iters) + log(f"Pair fw time: {t}, mem: {m}", level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + t, m, _ = benchmark_trace(pair.bw, iters=self.benchmark_iters) + log(f"Pair bw time: {t}, mem: {m}", level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + + if pair_cost_time < min_value_time: + best_pair_runtime = OutputCandidate(fw=pair.fw, bw=pair.bw, cost=pair_cost_time) + log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.INFO) + min_value_time = pair_cost_time + + if pair_cost_mem < min_value_mem: + best_pair_memory = OutputCandidate(fw=pair.fw, bw=pair.bw, cost=pair_cost_mem) + log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.INFO) + min_value_mem = pair_cost_mem + # Apply remat and select best trace pair pair_cost_time = 0 pair_cost_mem = 0 diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index f05ed6634c..0d8bdafc72 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -290,9 +290,6 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl out = None torch.cuda.empty_cache() - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - # Warm up cycles for _ in range(warm_up_iters): fn(*args) @@ -306,16 +303,20 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl stream = torch.cuda.current_stream() max_allocated_bytes = 0 torch.cuda.synchronize() - for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - fn(*args) - end_events[i].record(stream) - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) - ) + warm_up_iters = 10 + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + for i in range(iters + warm_up_iters): + if i >= warm_up_iters: + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i-warm_up_iters].record(stream) + fn(*args) + end_events[i-warm_up_iters].record(stream) + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) + ) torch.cuda.synchronize() times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] From 04b3139e8e8235ac29e8b6b70ab419e1d1a629ff Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 7 Aug 2024 19:32:43 +0300 Subject: [PATCH 031/171] Fixed nv fuser compile options, now good traces will be generated (#12) --- examples/dev/LLaMAMLP.py | 2 +- examples/dev/nv_compile_options.py | 45 ++++ thunder/backend_optimizer/optimizer.py | 270 +++++++++++--------- thunder/backend_optimizer/utils.py | 333 +++++++++++++++---------- 4 files changed, 403 insertions(+), 247 deletions(-) create mode 100644 examples/dev/nv_compile_options.py diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index db208eb108..074c6e6a58 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -24,7 +24,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = LLaMAMLP(a, b) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='memory', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) diff --git a/examples/dev/nv_compile_options.py b/examples/dev/nv_compile_options.py new file mode 100644 index 0000000000..29d54634ff --- /dev/null +++ b/examples/dev/nv_compile_options.py @@ -0,0 +1,45 @@ +import torch +import thunder +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark + +class Module(torch.nn.Module): + def __init__(self, in_features, out_features) -> None: + super().__init__() + self.linear_a = torch.nn.Linear(in_features, out_features) + self.linear_b = torch.nn.Linear(out_features, in_features) + self.linear_c = torch.nn.Linear(in_features, out_features) + self.linear_d = torch.nn.Linear(out_features, in_features) + self.silu = torch.nn.SiLU() + + def forward(self, x: torch.Tensor): + b = self.linear_d(self.linear_c(self.linear_b(self.linear_a(x)))) + c = b @ torch.transpose(b, 0, 1) + for _ in range(10): + c = c @ torch.transpose(c, 0, 1) + return self.silu(c) + +with torch.device('cuda'): + in_features = 1 << 8 + out_features = 1 << 10 + model = Module(in_features, out_features) + x = torch.randn(1 << 9, in_features, requires_grad=True) + + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'cudnn', 'torch', 'python'], ) + + y = jmodel_def(x) + y = jmodel_auto(x) + + print('Results thunder benchmark:') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] + thunder_fw_bw_benchmark(traces, labels, 50) + + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] + print('Results torch benchmark:') + torch_fw_bw_benchmark(callables, labels, inputs, 50) + + for t in traces: + print(f'{t}\n#########################################') diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 8744b5477a..94baaafe48 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,11 +1,10 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from enum import Enum -from thunder.backend_optimizer.utils import operation_in_trace +from thunder.backend_optimizer.utils import operation_in_trace, wrap_fn_with_exeuctor_compile_option from thunder.core.prims import PrimIDs from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, TensorProxy from thunder.core.symbol import BoundSymbol from thunder.core.trace import from_trace, TraceCtx -from thunder.executors.data_dependent_partition import Node from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Hashable @@ -44,19 +43,34 @@ class OptimizationAlgorithm(Enum): BEST_FUSER = 0 -class OptimizerNode: - def __init__(self, node: Node): - self.node: Node = node - self.candidate_executors: dict[Hashable, float] = {} +class FusionCompileOptionsHelper: + def __init__(self, fusion_tag: str, symbol_tag: str, id: PrimIDs, impl: Callable, checker: Callable) -> None: + self.fusion_tag = fusion_tag + self.symbol_tag = symbol_tag + self.id: PrimIDs = id + self.impl: Callable = impl + self.checker: Callable = checker - def add_candidate(self, ex: Executor, benchmark: float): - self.candidate_executors[ex] = benchmark + +class TraceCandidate: + def __init__(self, *, trace: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, label: str) -> None: + self.trace: TraceCtx = trace + self.compile_opt: FusionCompileOptionsHelper | None = compile_opt + self.label: str = label class TraceCandidates: - def __init__(self, best_time: TraceCtx | None = None, best_mem: TraceCtx | None = None) -> None: + def __init__( + self, + best_time: TraceCtx | None = None, + best_mem: TraceCtx | None = None, + compile_opt_time: FusionCompileOptionsHelper | None = None, + compile_opt_mem: FusionCompileOptionsHelper | None = None, + ) -> None: self.best_time: TraceCtx | None = best_time self.best_mem: TraceCtx | None = best_mem + self.compile_opt_time: FusionCompileOptionsHelper | None = compile_opt_time + self.compile_opt_mem: FusionCompileOptionsHelper | None = compile_opt_mem def __repr__(self) -> str: return f"\nBest runtime candidate:\n{self.best_time}\nBest memory candidate:\n{self.best_mem}" @@ -73,11 +87,17 @@ def attach_best_mem_candidate(self, trace: TraceCtx): def iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: return self.best_time, self.best_mem + def compile_opt_iterables(self) -> tuple[FusionCompileOptionsHelper | None, FusionCompileOptionsHelper | None]: + return self.compile_opt_time, self.compile_opt_mem + class OutputCandidate: - def __init__(self, *, fw: TraceCtx, bw: TraceCtx, cost: float = 0.0) -> None: + def __init__( + self, *, fw: TraceCtx, bw: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, cost: float = 0.0 + ) -> None: self.fw: TraceCtx = fw self.bw: TraceCtx = bw + self.compile_opt: FusionCompileOptionsHelper | None = compile_opt self.tot_cost: float = cost def __repr__(self) -> str: @@ -91,16 +111,22 @@ def __repr__(self) -> str: class FusionStratHelper: def __init__(self) -> None: self.supported_executors: set = set(["nvfuser", "torchcompile"]) - self.optimized_traces_mem: list[dict[str | Hashable, TraceCtx]] = [] + self.optimized_traces_mem: list[dict[str | Hashable, tuple[TraceCtx, FusionCompileOptionsHelper | None]]] = [] self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] - self.optimized_traces_time: list[dict[str | Hashable, TraceCtx]] = [] + self.optimized_traces_time: list[dict[str | Hashable, tuple[TraceCtx, FusionCompileOptionsHelper | None]]] = [] self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] +class FusionExecutorsPlacementCtx: + def __init__(self, *, placement: list, compile_options: FusionCompileOptionsHelper | None = None) -> None: + self.placement: list = placement + self.compile_options: FusionCompileOptionsHelper | None = compile_options + + class ExecutorPlacementOptions: def __init__(self) -> None: - self.placement_options_mem: list[list[Executor]] = [] - self.placement_options_time: list[list[Executor]] = [] + self.placement_options_mem: list[FusionExecutorsPlacementCtx] = [] + self.placement_options_time: list[FusionExecutorsPlacementCtx] = [] class LogLevel(Enum): @@ -116,12 +142,6 @@ def log(what: str, level: LogLevel): print(f"================================================================================ Autotune: {what}") -class FusionCompileOptionsHelper: - def __init__(self, fusion_tag: str, symbol_tag: str) -> None: - self.fusion_tag = fusion_tag - self.symbol_tag = symbol_tag - - class FusionPlacer: def __init__( self, @@ -151,8 +171,8 @@ def __init__( self.optimizer_type: OptimizerType = optimizer_type - self.active_fw_trace: TraceCtx | None = None - self.cached_fw_traces: dict[str | Hashable, TraceCandidates] = {} + self.active_fw_trace_ctx: tuple[TraceCtx | None, FusionCompileOptionsHelper | None] = None, None + self.cached_fw_traces: list[TraceCandidate] = [] self.bw_trace_candidates: TraceCandidates = TraceCandidates() self.out_traces_candidates: list[OutputCandidate] = [] self.best_pair_runtime: OutputCandidate @@ -168,12 +188,15 @@ def __init__( self.compile_data = compile_data + from thunder.executors.nvfuserex_impl import linear, _linear_check + from thunder.executors.nvfuserex_impl import matmul, _matmul_check + self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { - # "nvfuser": [ - # FusionCompileOptionsHelper("nv_enable_linear", "linear"), - # FusionCompileOptionsHelper("nv_enable_matmul", "matmul"), - # FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), - # ] + "nvfuser": [ + FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), + FusionCompileOptionsHelper("nv_enable_matmul", "matmul", PrimIDs.MATMUL, matmul, _matmul_check), + # FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), + ] } """ @@ -188,6 +211,7 @@ def _best_runtime_and_memory_candidates(self, candidates): min_value_mem: float = float("inf") best_pair_runtime: OutputCandidate best_pair_memory: OutputCandidate + pair: OutputCandidate for pair in candidates: # No remat pair_cost_time = 0 @@ -214,7 +238,14 @@ def _best_runtime_and_memory_candidates(self, candidates): # Apply remat and select best trace pair pair_cost_time = 0 pair_cost_mem = 0 - remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) + # In order to call rematerialize_forward_and_backward we need to set the cached compile options + # derived from the forward trace generation + if pair.compile_opt: + remat_fw, remat_bw = wrap_fn_with_exeuctor_compile_option( + pair.compile_opt, rematerialize_forward_and_backward, pair.fw, pair.bw + ) + else: + remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) log(f"Pair fw time: {t}, mem: {m}", level=LogLevel.INFO) pair_cost_time = pair_cost_time + t @@ -243,12 +274,13 @@ def _filter_candidates(self): # Number of fw traces to cached are: #fusion_executors * 2 def fw_benchmark(): # The optimizator builds the results in order following the self.fusion_executors list order + pair_time: dict + pair_mem: dict for pair_time, pair_mem in zip( self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem ): - # pair is a dict - trc_time = list(pair_time.values())[0] - trc_mem = list(pair_mem.values())[0] + trc_time, compile_opt_time = list(pair_time.values())[0] + trc_mem, compile_opt_mem = list(pair_mem.values())[0] label = list(pair_time.keys())[0] # TODO (matteochen): remove the benchmark here as will done later on the bw pass c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) @@ -267,8 +299,19 @@ def fw_benchmark(): self.debug_msg += ( f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) + # For forward trace we cache the best placement for both runtime and memory for the current Fusion executor (represented by label) + if compile_opt_time is not None: + print(f"Caching fw with compile options time: {compile_opt_time.fusion_tag}") + if compile_opt_mem is not None: + print(f"Caching fw with compile options mem: {compile_opt_mem.fusion_tag}") + + for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): + print(f'Caching fw candidate\n{t}\nwith option {o.fusion_tag if o else "None"}') + self.cached_fw_traces.append( + TraceCandidate(trace=t, compile_opt=o, label=label + '_enabled_' + o.fusion_tag if o is not None else label) + ) - self.cached_fw_traces[label] = TraceCandidates(best_time=trc_time, best_mem=trc_mem) + log("End fw time mem pair", level=LogLevel.INFO) def bw_benchmark(): time_result = BenchmarkResult() @@ -318,15 +361,17 @@ def bw_benchmark(): # Here we have to recover the traces without the pass through remat in order to be compliant # with thunder flow as we might have request for no remat # Unpack dict - trc = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0] + trc = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0][0] self.bw_trace_candidates.attach_best_time_candidate(trc) - trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0] + trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0][0] self.bw_trace_candidates.attach_best_mem_candidate(trc) # Now, finally build the pair fw and bw traces # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller for bw in self.bw_trace_candidates.iterable(): - self.out_traces_candidates.append(OutputCandidate(fw=self.active_fw_trace, bw=bw)) + self.out_traces_candidates.append( + OutputCandidate(fw=self.active_fw_trace_ctx[0], bw=bw, compile_opt=self.active_fw_trace_ctx[1]) + ) match self.trace_type: case TraceType.FW: @@ -336,6 +381,7 @@ def bw_benchmark(): if self.produce_log: import time + timestamp: str = str(time.time()) with open(f"{timestamp}-{self.log_file_name}", "w") as file: file.write(self.debug_msg) @@ -406,7 +452,7 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo ) return placed_trace, keys, executor_configuration - def search(ex: FusionExecutor): + def _search(ex: FusionExecutor, executor_compile_option: FusionCompileOptionsHelper | None = None): """ Fusable fn definition for nvFuser """ @@ -506,14 +552,17 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # Helpers candidate_best_time = BenchmarkResult() candidate_best_mem = BenchmarkResult() - # Search for best candidate + # Search for best candidate, by default remat will be called to find the optimal choice + # TODO: enable requests for no remat becnhmarks for i, candidate in enumerate(candidate_executors): # Match the current candidate to benchmark partial trace match_bsym_output(current_bsym, [dict_time_strat, dict_mem_strat], candidate) # Retrieve partial trace and benchmark, apply remat if possible trc, _, _ = get_placed_trace(dict_time_strat, increasing_symbols) - if self.trace_type == TraceType.BW and self.active_fw_trace is not None: - _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) + # Apply fw bw remat + if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: + _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) + # Now, benchmark t, m, _ = benchmark_trace(trc, self.benchmark_iters) # Update results if t < candidate_best_time.runtime: @@ -558,8 +607,8 @@ def measure_and_update_result(): nonlocal best_placement_mem nonlocal best_keys_mem trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) - if self.trace_type == TraceType.BW and self.active_fw_trace is not None: - _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) + if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: + _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) cost, mem, _ = benchmark_trace(trc, self.benchmark_iters) log(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) if cost < best_res_time.runtime or (cost == best_res_time.runtime and mem < best_res_time.memory): @@ -688,7 +737,7 @@ def measure_and_update_result(): trace.bound_symbols.append( self.trace.bound_symbols[-1].from_bsym(args=get_not_used_intermediate_outsputs(trace)) ) - # Save the optimal traces that we have found + # Save the optimal traces (both for runtime and memory consumption) that we have found for executors, container in zip( [executors_mem, executors_time], [ @@ -702,13 +751,20 @@ def measure_and_update_result(): always_executors=self.always_executors, empty_str=self.empty_executor_hashable_placeholder, ) + # print(f"Assigned trace:\n{trc}") if self.trace_type == TraceType.BW: - _, trc = rematerialize_forward_and_backward(self.active_fw_trace, trc) + # pass + _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) container.append({ex.name: trc}) # Save executors in order to generate real fw and bw trace with correct output with the placer - self.executor_placement_options.placement_options_time.append(executors_time) - self.executor_placement_options.placement_options_mem.append(executors_mem) + # We add any provided compile option reference + self.executor_placement_options.placement_options_time.append( + FusionExecutorsPlacementCtx(placement=executors_time, compile_options=executor_compile_option) + ) + self.executor_placement_options.placement_options_mem.append( + FusionExecutorsPlacementCtx(placement=executors_mem, compile_options=executor_compile_option) + ) # If executor specific compile option is activated we need to know where a specific # trace does come from and the zip logic afterward can not be employed with self.fusion_executors list @@ -720,37 +776,25 @@ def measure_and_update_result(): log(f"Searching best placement for fusion executor = {ex.name}", level=LogLevel.INFO) - # We try to enable fusion specific compile options - ex_compile_opts = self.known_fusion_ex_compile_options.get(ex.name, []) + # We try to enable fusion specific compile options only for fw traces + # Backward traces will follow fw traces options + ex_compile_opts = ( + self.known_fusion_ex_compile_options.get(ex.name, []) if self.trace_type == TraceType.FW else [] + ) self.fusion_executors_saved_for_later.append(ex) # Always search with option disabled -> standard flow - search(ex) + _search(ex) # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. - # Consider implementing patters based on the executor under investingation + # TODO: Consider implementing patters based on the executor under investingation if ex_compile_opts: for opt in ex_compile_opts: + # Search only if we have an instruction related to the compile option op_in_trace: bool = operation_in_trace(trace=self.trace, op=opt.symbol_tag) if op_in_trace: - # Search with option enabled - old_opt: bool | None = self.compile_data.compile_options.get(opt.fusion_tag, None) - # We test the inverse of the default one - new_opt = True if old_opt is None or old_opt is False else False - - # For nv_enable_bookend, by default is it defaulted to True, hence we try the False path - # https://github.com/Lightning-AI/lightning-thunder/blob/73b31f35ff95b08ceee7d5d5344127d619fd37fe/thunder/executors/nvfuserex_impl.py#L784 - if opt.fusion_tag == "nv_enable_bookend": - new_opt = False - - log( - f"Executor {ex.name} enabling compile option: {opt.fusion_tag} with value {new_opt}", - level=LogLevel.INFO, - ) - self.compile_data.compile_options[opt.fusion_tag] = new_opt self.fusion_executors_saved_for_later.append(ex) - search(ex) - self.compile_data.compile_options[opt.fusion_tag] = old_opt + wrap_fn_with_exeuctor_compile_option(opt, _search, ex, opt) """ ################################################## Public methods ################################################## @@ -759,11 +803,7 @@ def measure_and_update_result(): def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: if not self.cached_fw_traces: raise AssertionError("Failed to obtain optimal fw traces") - return [ - getattr(candidate, field) - for candidate in self.cached_fw_traces.values() - for field in ["best_time", "best_mem"] - ] + return [candidate.trace for candidate in self.cached_fw_traces] def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: return ( @@ -810,44 +850,48 @@ def _optimize(): self.fusion_executors_saved_for_later ): raise AssertionError( - f"Unexpected time placement options size: {len(self.executor_placement_options.placement_options_time)}. Expected: {len(self.fusion_executors)}" + f"Unexpected time placement options size: {len(self.executor_placement_options.placement_options_time)}. Expected: {len(self.fusion_executors_saved_for_later)}" ) if len(self.executor_placement_options.placement_options_mem) != len(self.fusion_executors_saved_for_later): raise AssertionError( - f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors)}" + f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors_saved_for_later)}" ) - - for placement, ex in zip( + for placement_ctx, ex in zip( self.executor_placement_options.placement_options_time, self.fusion_executors_saved_for_later ): + trc = assign_executors( + in_trace=self.trace, + executor_list=placement_ctx.placement, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + compile_data=self.compile_data, + fusion_executor_compile_options_to_activate=placement_ctx.compile_options, + ) self.fusion_strat_helper.optimized_traces_time.append( - { - ex.name: assign_executors( - in_trace=self.trace, - executor_list=placement, - always_executors=self.always_executors, - empty_str=self.empty_executor_hashable_placeholder, - ) - } + {ex.name: tuple([trc, placement_ctx.compile_options])} ) - for placement, ex in zip( + for placement_ctx, ex in zip( self.executor_placement_options.placement_options_mem, self.fusion_executors_saved_for_later ): + trc = assign_executors( + in_trace=self.trace, + executor_list=placement_ctx.placement, + always_executors=self.always_executors, + empty_str=self.empty_executor_hashable_placeholder, + compile_data=self.compile_data, + fusion_executor_compile_options_to_activate=placement_ctx.compile_options, + ) self.fusion_strat_helper.optimized_traces_mem.append( - { - ex.name: assign_executors( - in_trace=self.trace, - executor_list=placement, - always_executors=self.always_executors, - empty_str=self.empty_executor_hashable_placeholder, - ) - } + {ex.name: tuple([trc, placement_ctx.compile_options])} ) + # Filter out the optimal candidates for the current serach iteration self._filter_candidates() match self.trace_type: case TraceType.FW: + # Clear any previous results + self.cached_fw_traces = [] _optimize() # We have multiple cached optimized fw traces, find the best backward case TraceType.BW: @@ -857,31 +901,35 @@ def _optimize(): # Cached the bw trace as we need to modify the input trace during the loop cached_self_trace = from_trace(self.trace) cached_self_trace.bound_symbols = list(self.trace.bound_symbols) - for label, candidate in self.cached_fw_traces.items(): - log(f"Backward optimization with fw from {label}", level=LogLevel.INFO) - fw_traces = candidate.iterable() - for trc in fw_traces: - # TODO (matteochen): unify below with the original block - # Restore the original bw trace - self.trace = from_trace(cached_self_trace) - self.trace.bound_symbols = list(cached_self_trace.bound_symbols) - # Set the current active cached forward trace - self.active_fw_trace = trc + # Now we can generate backward solutions from the cached fw traces + for fw_trace_candidate in self.cached_fw_traces: + log(f"Backward optimization with fw from {fw_trace_candidate.label}", level=LogLevel.INFO) + # Restore the original bw trace + self.trace = from_trace(cached_self_trace) + self.trace.bound_symbols = list(cached_self_trace.bound_symbols) + # Set the current active cached forward trace context + print( + f'Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.compile_opt.fusion_tag if fw_trace_candidate.compile_opt is not None else "None"}' + ) + self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.compile_opt - log(f"Cached fw trace:\n{self.active_fw_trace}", level=LogLevel.INFO) - log(f"Input bw trace:\n{self.trace}", level=LogLevel.DEBUG) + log(f"Input bw trace:\n{self.trace}", level=LogLevel.DEBUG) - self.trace = update_bw_from_forward_optimization(fw=trc, bw=self.trace) + self.trace = update_bw_from_forward_optimization(fw=fw_trace_candidate.trace, bw=self.trace) - if self.apply_bucketing_bw_trace: - from thunder.distributed.transforms import FSDPCommBucketing + if self.apply_bucketing_bw_trace: + from thunder.distributed.transforms import FSDPCommBucketing - self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace(self.trace) + self.trace = FSDPCommBucketing.apply_bucketing_to_backward_trace(self.trace) - # Not called in the constructor for bw traces - dce(self.trace) + # Not called in the constructor for bw traces + self.trace = dce(self.trace) + # Enable any forward active compilation flag + if fw_trace_candidate.compile_opt: + wrap_fn_with_exeuctor_compile_option(fw_trace_candidate.compile_opt, _optimize) + else: _optimize() self.best_pair_runtime, self.best_pair_memory = self._best_runtime_and_memory_candidates( diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 0d8bdafc72..3afff9fe63 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -12,6 +12,7 @@ import torch import thunder + def sequence_hash(s: Sequence) -> str: name = "" for e in s: @@ -31,13 +32,17 @@ def sequence_hash(s: Sequence) -> str: raise AssertionError(f"What? Maybe nested Sequence. type = {type(e)}") return name + def can_executor_execute(ex: Executor, bsym: BoundSymbol) -> bool: try: return ex.can_execute(bsym) except Exception: return False -def get_first_available_operator_executor(*, bsym: BoundSymbol, executors: Sequence[Executor], empty_hash: str = 'empty'): + +def get_first_available_operator_executor( + *, bsym: BoundSymbol, executors: Sequence[Executor], empty_hash: str = "empty" +): for ex in executors: if isinstance(ex, FusionExecutor): continue @@ -80,157 +85,167 @@ def is_possible_out(name: str): ans.append(b.output) return ans + def assign_executors( *, in_trace: TraceCtx, - executor_list: list[Executor] | tuple[Executor, ...], + executor_list: list[Executor | FusionExecutor | OperatorExecutor] + | tuple[Executor | FusionExecutor | OperatorExecutor, ...], always_executors: list[Executor] | tuple[Executor, ...], empty_str: str | Hashable, + compile_data=None, + fusion_executor_compile_options_to_activate: Any | None = None, ) -> TraceCtx: from thunder.executors.passes import _transform_for_operator_executor_execution - swapmap: dict[Variable, Proxy] = {} - - def restore_correct_args(trace_in: TraceCtx): - def args_eq(a, b) -> bool: - if len(a) != len(b): - return False - for obj_a, obj_b in zip(a, b): - if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): - if obj_a.name != obj_b.name: - return False - elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): - if obj_a != obj_b: - raise AssertionError(f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") + def _assign_executors(): + swapmap: dict[Variable, Proxy] = {} + + def restore_correct_args(trace_in: TraceCtx): + def args_eq(a, b) -> bool: + if len(a) != len(b): + return False + for obj_a, obj_b in zip(a, b): + if type(obj_a) == type(obj_b) and isinstance(obj_a, TensorProxy): + if obj_a.name != obj_b.name: + return False + elif type(obj_a) == type(obj_b) and not isinstance(obj_a, TensorProxy): + if obj_a != obj_b: + raise AssertionError(f"What do you want to do here:\nobj_a:\n{obj_a}\nobj_b:{obj_b}") + return True + + def clear(bsym: BoundSymbol, input): + size = len(bsym.subsymbols) + if size > 0: + for subsym in bsym.subsymbols: + if not args_eq(subsym.args, input): + subsym.args = tuple(list(input)) + clear(subsym, input) + + for bsym in trace_in.bound_symbols: + if isinstance(bsym.sym.executor, OperatorExecutor): + clear(bsym, bsym.args) + + def update_swapmap(o: Any, no: Any) -> None: + if isinstance(o, Proxy): + check( + isinstance(no, Proxy), + lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}", + ) + + vo = variableify(o) + vno = variableify(no) + if vo == vno: + return + swapmap[vno] = o + + def preserve_bsym(bsym: BoundSymbol) -> Any: + trace: TraceCtx | None = get_tracectx() + if trace is None: + raise AssertionError("None trace context") + trace.scopes[-1].append(bsym) + for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): + trace.names.add(p.name) + return bsym.output + + def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: + if bsym.sym.python_impl is not None: + return None + + # We have mapped this at previous stages + if ex.name == empty_str: + return None + + execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) + out: Any + if execution_transform is not None: + out = execution_transform(*bsym.args, **bsym.kwargs) + elif isinstance(ex, OperatorExecutor): + # Calls the operator executor's operation + op: Symbol | None = ex.implmap[bsym.sym.id].symbol + if op is None: + raise AssertionError("op is None") + out = op(*bsym.args, **bsym.kwargs) + elif isinstance(ex, FusionExecutor): + # Preserves the symbol as is (it will be handled in the fusion pass) + out = preserve_bsym(bsym) + else: + raise AssertionError("Unknown executor") + + safe_map_flat(update_swapmap, bsym.output, out) + return True - def clear(bsym: BoundSymbol, input): - size = len(bsym.subsymbols) - if size > 0: - for subsym in bsym.subsymbols: - if not args_eq(subsym.args, input): - subsym.args = tuple(list(input)) - clear(subsym, input) - - for bsym in trace_in.bound_symbols: - if isinstance(bsym.sym.executor, OperatorExecutor): - clear(bsym, bsym.args) - - def update_swapmap(o: Any, no: Any) -> None: - if isinstance(o, Proxy): - check( - isinstance(no, Proxy), - lambda: f"Expected an execution transform to produce outputs with the same type, but found {type(o)} and {type(no)}", - ) + def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: + return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE - vo = variableify(o) - vno = variableify(no) - if vo == vno: - return - swapmap[vno] = o - - def preserve_bsym(bsym: BoundSymbol) -> Any: - trace: TraceCtx | None = get_tracectx() - if trace is None: - raise AssertionError("None trace context") - trace.scopes[-1].append(bsym) - for p in chain(bsym.flat_proxy_outs, bsym.flat_proxy_args): - trace.names.add(p.name) - return bsym.output - - def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: - if bsym.sym.python_impl is not None: - return None - - # We have mapped this at previous stages - if ex.name == empty_str: - return None - - execution_transform: None | Callable = ex.get_execution_transform(bsym.sym) - out: Any - if execution_transform is not None: - out = execution_transform(*bsym.args, **bsym.kwargs) - elif isinstance(ex, OperatorExecutor): - # Calls the operator executor's operation - op: Symbol | None = ex.implmap[bsym.sym.id].symbol - if op is None: - raise AssertionError("op is None") - out = op(*bsym.args, **bsym.kwargs) - elif isinstance(ex, FusionExecutor): - # Preserves the symbol as is (it will be handled in the fusion pass) - out = preserve_bsym(bsym) - else: - raise AssertionError("Unknown executor") + if len(executor_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executor_list) != len(in_trace.bound_symbols)") - safe_map_flat(update_swapmap, bsym.output, out) + cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} + executor_mapping: dict[str, Executor] = {} + unique_fusion_executors = set() - return True + # Input should have equal length + if len(executor_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") - def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: - return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE - - if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError("len(executor_list) != len(in_trace.bound_symbols)") - - cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} - executor_mapping: dict[str, Executor] = {} - unique_fusion_executors = set() - - # Input should have equal length - if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") - - for b, e in zip(in_trace.bound_symbols, executor_list): - if isinstance(e, FusionExecutor): - unique_fusion_executors.add(e) - if isinstance(b.output, TensorProxy): - executor_mapping[b.output.name] = e - - extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) - - # Restores original variables - bound_symbols: list[BoundSymbol] = [] - for bsym in extrace.bound_symbols: - nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) - bound_symbols.append(nbsym) - extrace.bound_symbols = bound_symbols - - for bsym in extrace.bound_symbols: - if isinstance(bsym.output, TensorProxy): - t_name = bsym.output.name - if t_name not in executor_mapping: - # Symbol added by the visitor - continue - # raise AssertionError('Failed to retrive key in mapping') - saved_ex = executor_mapping[t_name] - if isinstance(saved_ex, OperatorExecutor): - cached_subsymbols[t_name] = list(bsym.subsymbols) - # This will leave out these symbols from the fusion pass - bsym.subsymbols = [] - - # Perform fusion pass - for ex in unique_fusion_executors: - extrace = ex.fusion_pass(extrace) - - # Restore subsymbols - # TODO (matteochen): Improve this search - for k, v in cached_subsymbols.items(): - # Note some symbols may be cut out by the fusion pass -> CSE - # For example: - # a = 1 + 1 - # b = 1 + 1 - # c = a + b - # being replaced by c = a + a - for bsym in extrace.bound_symbols: - if isinstance(bsym.output, TensorProxy) and bsym.output.name == k: - bsym.subsymbols = v + for b, e in zip(in_trace.bound_symbols, executor_list): + if isinstance(e, FusionExecutor): + unique_fusion_executors.add(e) + if isinstance(b.output, TensorProxy): + executor_mapping[b.output.name] = e - restore_correct_args(extrace) + extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) - # Apply always executors - extrace = _transform_for_operator_executor_execution(extrace, always_executors) + # Restores original variables + bound_symbols: list[BoundSymbol] = [] + for bsym in extrace.bound_symbols: + nbsym: BoundSymbol = bsym.from_bsym_swap_proxies(swapmap) + bound_symbols.append(nbsym) + extrace.bound_symbols = bound_symbols + + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy): + t_name = bsym.output.name + if t_name not in executor_mapping: + # Symbol added by the visitor + continue + # raise AssertionError('Failed to retrive key in mapping') + saved_ex = executor_mapping[t_name] + if isinstance(saved_ex, OperatorExecutor): + cached_subsymbols[t_name] = list(bsym.subsymbols) + # This will leave out these symbols from the fusion pass + bsym.subsymbols = [] + + # Perform fusion pass + for ex in unique_fusion_executors: + extrace = ex.fusion_pass(extrace) + + # Restore subsymbols + # TODO (matteochen): Improve this search + for k, v in cached_subsymbols.items(): + # Note some symbols may be cut out by the fusion pass -> CSE + # For example: + # a = 1 + 1 + # b = 1 + 1 + # c = a + b + # being replaced by c = a + a + for bsym in extrace.bound_symbols: + if isinstance(bsym.output, TensorProxy) and bsym.output.name == k: + bsym.subsymbols = v + + restore_correct_args(extrace) + + # Apply always executors + extrace = _transform_for_operator_executor_execution(extrace, always_executors) + + return extrace + + if fusion_executor_compile_options_to_activate: + return wrap_fn_with_exeuctor_compile_option(fusion_executor_compile_options_to_activate, _assign_executors) + return _assign_executors() - return extrace def operation_in_trace(*, trace: TraceCtx, op: str) -> bool: # Some optimizations are not available as symbols @@ -243,6 +258,7 @@ def operation_in_trace(*, trace: TraceCtx, op: str) -> bool: return True return False + def benchmark_trace( trace: TraceCtx, iters: int = 1, @@ -502,3 +518,50 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: reset_tracectx(trace_tok) return t, m, answer + + +def register_impl_executor(ex: Executor, id: PrimIDs, fn: Callable, checker: Callable) -> None: + if ex.name == "nvfuser": + from thunder.executors.nvfuserex_impl import register_supported + + register_supported(id, fn, checker) + + +def recover_ex_from_compile_option(option: str) -> Executor: + if option.startswith("nv"): + from thunder.executors.nvfuserex_impl import ex + + return ex + else: + raise AssertionError(f"Compile option not recognized: {option}") + + +def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *args): + from thunder.core import compile_data + + cd = compile_data.get_compile_data() + if option is not None: + # Update compile option context + if cd is None: + raise AssertionError("compile_data is None") + # TODO: use getattr + old_opt: bool | None = cd.compile_options.get(option.fusion_tag, None) + new_opt = True if old_opt is None or old_opt is False else False + cd.compile_options[option.fusion_tag] = new_opt + # Register the impl for the executor in order to be able to execute the id + register_impl_executor( + recover_ex_from_compile_option(option.fusion_tag), + option.id, + option.impl, + option.checker, + ) + # Call fn and return output + if fn: + out = fn(*args) + else: + out = None + # Restore compile option + if option is not None: + cd.compile_options[option.fusion_tag] = old_opt + + return out From fd44dc266dd939f17058a2554b520b284295adbc Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 7 Aug 2024 22:31:16 +0300 Subject: [PATCH 032/171] Allowing duplicates during eval trace write / Added missing fn call during benchmark utils (#13) --- thunder/backend_optimizer/utils.py | 2 ++ thunder/core/transforms.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 3afff9fe63..e10d20d93e 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -333,6 +333,8 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl max_allocated_bytes = max( max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) ) + else: + fn(*args) torch.cuda.synchronize() times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index f14f89ddcb..1d40046cda 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1610,7 +1610,7 @@ def read(x: Variable): else: return x - def write(v: Variable, val: Any, allow_duplicates=False) -> None: + def write(v: Variable, val: Any, allow_duplicates=True) -> None: if not isinstance(v, Variable): return # Duplicates are allowed and overwritten From 670a58235546cb5937c33d6c4b99a8a31c84b4ad Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 7 Aug 2024 22:34:16 +0300 Subject: [PATCH 033/171] Updated test runner --- examples/dev/litGPT.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 56f3430839..344541190b 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -5,26 +5,26 @@ import torch class Test: - def __init__(self, layers: int, autotune_type: str) -> None: + def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: self.layers = layers self.autotune_type = autotune_type + self.batch_size = batch_size -layers = [Test(8, 'runtime')] +layers = [Test(8, 'runtime', 1), Test(8, 'runtime', 4)] model_name = 'Llama-3-8B' for test in layers: try: - print('Layers:', test.layers) + print('\n\nLayers:', test.layers) cfg = Config.from_name(model_name) cfg.n_layer = test.layers torch.set_default_dtype(torch.bfloat16) with torch.device('cuda'): model = GPT(cfg) - x = torch.randint(1, model.config.vocab_size, (1, 512)) + x = torch.randint(1, model.config.vocab_size, (test.batch_size, 512)) jmodel_def = thunder.jit(model) - # Torchcompile gives some troubles for now jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'torchcompile', 'cudnn', 'torch', 'python']) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) From dfc7fdc7b36b3a776e9749830ef3bf70b67e26e9 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 7 Aug 2024 22:36:20 +0300 Subject: [PATCH 034/171] Prev commit --- examples/dev/litGPT.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 344541190b..96a6d7f19b 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -30,17 +30,17 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) - - print('Results thunder benchmark:') + iters = 100 + print(f'Results thunder benchmark ({iters} iters):') traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, 50) + thunder_fw_bw_benchmark(traces, labels, iters) - print('\n\nResults torch fw bw benchmark:') + print(f'\n\nResults torch fw bw benchmark ({iters} iters):') callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] inputs = [x, x] - torch_fw_bw_benchmark(callables, labels, inputs, 50) + torch_fw_bw_benchmark(callables, labels, inputs, iters) print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') From 9b8eb4d52211f405be9b0c5db7c282506337ebce Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 7 Aug 2024 22:38:13 +0300 Subject: [PATCH 035/171] Added comment --- thunder/backend_optimizer/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 94baaafe48..01f51a51d8 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -239,7 +239,7 @@ def _best_runtime_and_memory_candidates(self, candidates): pair_cost_time = 0 pair_cost_mem = 0 # In order to call rematerialize_forward_and_backward we need to set the cached compile options - # derived from the forward trace generation + # derived from the forward trace generation. At this stage all the infos are contained inside the pair object. if pair.compile_opt: remat_fw, remat_bw = wrap_fn_with_exeuctor_compile_option( pair.compile_opt, rematerialize_forward_and_backward, pair.fw, pair.bw From 823398a244a04863920c0543a93c73201289d127 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 9 Aug 2024 13:46:40 +0300 Subject: [PATCH 036/171] Updated log --- thunder/backend_optimizer/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 01f51a51d8..99b335fd85 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -217,11 +217,11 @@ def _best_runtime_and_memory_candidates(self, candidates): pair_cost_time = 0 pair_cost_mem = 0 t, m, _ = benchmark_trace(pair.fw, iters=self.benchmark_iters) - log(f"Pair fw time: {t}, mem: {m}", level=LogLevel.INFO) + log(f"Pair fw time no remat: {t}, mem: {m}", level=LogLevel.INFO) pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m t, m, _ = benchmark_trace(pair.bw, iters=self.benchmark_iters) - log(f"Pair bw time: {t}, mem: {m}", level=LogLevel.INFO) + log(f"Pair bw time no remat: {t}, mem: {m}", level=LogLevel.INFO) pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m From a0f0d2b57a1dc86ee2eb3d8a720961fb45e696dd Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 12 Aug 2024 21:47:26 +0300 Subject: [PATCH 037/171] Defaults empty executors list to all executors --- thunder/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/thunder/__init__.py b/thunder/__init__.py index 8c825a7b44..be6b866fb6 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -330,6 +330,11 @@ def jit( else: raise AssertionError(f'Not supported optimization: {autotune_type}') + # Default the executors list to all_executors if no options are given + # Otherwise the user restricted choice will be used + if not executors: + executors = get_all_executors() + # Resolve names of executors executors = resolve_executors(executors) From 9b3215828b5e29cc63853c5ec872b9a761b48ddc Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 13 Aug 2024 15:04:04 +0300 Subject: [PATCH 038/171] Fixed formatting --- thunder/executors/torch_autograd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 4318b1d577..2b212f4f3a 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -146,7 +146,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.visualizer.visualizer_helper import Visualizer from thunder.backend_optimizer.optimizer import log, LogLevel, TraceType, BackendOptimizer, OptimizerType, benchmark_trace - def split(): + def split(): utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used # behind a cache. From fdeddc74de8eb3e8322010dd872abb6105f76cf8 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 13 Aug 2024 17:46:56 +0300 Subject: [PATCH 039/171] Cuda graphs integration / minors con benchmarks and profiler / made interface for Placer object --- examples/dev/litGPT.py | 4 +- examples/dev/nanogpt.py | 30 +-- ...le_options.py => nvfuser_optimizations.py} | 25 ++- examples/dev/simple.py | 12 +- thunder/__init__.py | 2 +- thunder/backend_optimizer/optimizer.py | 183 +++++++++++------- thunder/backend_optimizer/utils.py | 48 ++--- thunder/benchmarks/utils.py | 29 +-- thunder/common.py | 2 +- thunder/executors/cudagraphex.py | 9 +- thunder/executors/torch_autograd.py | 29 +-- 11 files changed, 205 insertions(+), 168 deletions(-) rename examples/dev/{nv_compile_options.py => nvfuser_optimizations.py} (63%) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 96a6d7f19b..3cb68b9993 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -1,5 +1,5 @@ from litgpt import GPT -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_total_benchmark from thunder.tests.litgpt_model import Config import thunder import torch @@ -41,6 +41,8 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: labels = ['def', 'auto'] inputs = [x, x] torch_fw_bw_benchmark(callables, labels, inputs, iters) + print(f'\n\nResults torch total benchmark ({iters} iters):') + torch_total_benchmark(callables, labels, inputs, iters) print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index c034c5bc0e..361d795d52 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -11,7 +11,7 @@ def run(target: str = 'runtime'): raise AssertionError(f'Target {target} not supported. Only runtime and memory available') # ----------------------------------------------------------------------------- batch_size = 12 - block_size = 1024 + block_size = 512 bias = False real_data = False seed = 1337 @@ -42,16 +42,16 @@ def run(target: str = 'runtime'): # model init gptconf = GPTConfig( block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 4, n_head = 12, n_embd = 768, # size of the model + n_layer = 1, n_head = 6, n_embd = 768, # size of the model dropout = 0, # for determinism bias = bias, ) model = GPT(gptconf) model.to(device) - jmodel_def = thunder.jit(model) + jmodel_def = thunder.jit(model, use_cudagraphs=True) # Currently sdpa does not work? - jmodel_auto = thunder.jit(model, autotune_type={target}, executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python'], use_cudagraphs=True) if compile: print("Compiling model...") @@ -81,7 +81,7 @@ def run(target: str = 'runtime'): X, Y = get_batch('train') for k in range(num_steps): with ctx: - _, loss = model(X, Y) + _, loss = mod(X, Y) X, Y = get_batch('train') loss.backward() lossf = loss.item() @@ -101,7 +101,7 @@ def measure(m, label): X, Y = get_batch('train') loss.backward() - iters = 5 + iters = 100 start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] stream = torch.cuda.current_stream() @@ -127,11 +127,16 @@ def measure(m, label): measure(jmodel_def, 'def') print('\n\nResults thunder benchmark:') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] traces.reverse() labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] labels.reverse() - thunder_fw_bw_benchmark(traces, labels, 5) + thunder_fw_bw_benchmark(traces, labels, 100) # X, Y = get_batch('train') # out_eager = model(X, Y) @@ -142,9 +147,8 @@ def measure(m, label): # for a, b in zip(out_eager, out_auto): # print('deviation auto:', (a - b).abs().max().item()) - # traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] - # for t in traces: - # print(f'{t}\n############################################') + traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + for t in traces: + print(f'{t}\n############################################') -run_memory() -run_time() +run() diff --git a/examples/dev/nv_compile_options.py b/examples/dev/nvfuser_optimizations.py similarity index 63% rename from examples/dev/nv_compile_options.py rename to examples/dev/nvfuser_optimizations.py index 29d54634ff..95abc755ba 100644 --- a/examples/dev/nv_compile_options.py +++ b/examples/dev/nvfuser_optimizations.py @@ -1,20 +1,22 @@ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: super().__init__() - self.linear_a = torch.nn.Linear(in_features, out_features) - self.linear_b = torch.nn.Linear(out_features, in_features) - self.linear_c = torch.nn.Linear(in_features, out_features) - self.linear_d = torch.nn.Linear(out_features, in_features) + self.linear = torch.nn.Sequential( + torch.nn.Linear(in_features, out_features), + torch.nn.Linear(out_features, in_features), + torch.nn.Linear(in_features, out_features), + torch.nn.Linear(out_features, in_features) + ) self.silu = torch.nn.SiLU() def forward(self, x: torch.Tensor): - b = self.linear_d(self.linear_c(self.linear_b(self.linear_a(x)))) + b = self.linear(x) c = b @ torch.transpose(b, 0, 1) - for _ in range(10): + for _ in range(4): c = c @ torch.transpose(c, 0, 1) return self.silu(c) @@ -25,13 +27,18 @@ def forward(self, x: torch.Tensor): x = torch.randn(1 << 9, in_features, requires_grad=True) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'cudnn', 'torch', 'python'], ) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'cudnn', 'torch', 'python']) y = jmodel_def(x) y = jmodel_auto(x) print('Results thunder benchmark:') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] thunder_fw_bw_benchmark(traces, labels, 50) diff --git a/examples/dev/simple.py b/examples/dev/simple.py index b995a83102..0add49560a 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -15,14 +15,13 @@ def forward(self, x: torch.Tensor): return self.silu(c) with torch.device('cuda'): - multiplier = 100 - in_features = 20 * multiplier - out_features = 30 * multiplier + in_features = 4096 + out_features = 11008 model = Module(in_features, out_features) x = torch.randn(128, in_features, requires_grad=True) - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'torchcompile', 'cudnn', 'sdpa', 'torch', 'python']) + jmodel_def = thunder.jit(model, ) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'torchcompile', 'cudnn', 'torch', 'python'], ) y = jmodel_def(x) y = jmodel_auto(x) @@ -37,3 +36,6 @@ def forward(self, x: torch.Tensor): inputs = [x, x] print('Results torch benchmark:') torch_fw_bw_benchmark(callables, labels, inputs, 50) + + for t in traces: + print(f'{t}\n###################') diff --git a/thunder/__init__.py b/thunder/__init__.py index be6b866fb6..61d6e0adeb 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -353,7 +353,7 @@ def jit( # TODO RC1 Refine the compile data option to remove unused options # TODO: refine options # NOTE(fixme): use_cudagraphs is being absorbed into compile_options - use_cudagraphs = compile_options.get("use_cudagraphs", False) + use_cudagraphs = compile_options.get("use_cudagraphs", None) cd = CompileData( fn=fn, langctx=langctx, diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 99b335fd85..00fbb0f16a 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -137,12 +137,11 @@ class LogLevel(Enum): log_level: LogLevel = LogLevel.INFO -def log(what: str, level: LogLevel): +def log(what: str, level: LogLevel = LogLevel.INFO): if log_level == LogLevel.DEBUG or log_level == level: print(f"================================================================================ Autotune: {what}") - -class FusionPlacer: +class PlacerBase: def __init__( self, *, @@ -153,7 +152,7 @@ def __init__( visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME, compile_data, - ) -> None: + ) -> None: self.always_executors: tuple[Executor, ...] = get_always_executors() self.empty_executor_hashable_placeholder: str = "empty" self.executors: Sequence[Executor] = priority_executors @@ -178,19 +177,52 @@ def __init__( self.best_pair_runtime: OutputCandidate self.best_pair_memory: OutputCandidate - # Strat fusion - self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() - self.executor_placement_options: ExecutorPlacementOptions = ExecutorPlacementOptions() - self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace self.benchmark_iters: int = 20 self.compile_data = compile_data + def optimize(self): + pass + + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + pass + + def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + return [] + + def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + return (TraceCtx(), TraceCtx()) + +class FusionPlacer_BeamSearch(PlacerBase): + def __init__( + self, + *, + priority_executors: Sequence[Executor], + produce_log: bool = True, + apply_bucketing_bw_trace: bool, + log_file_name: str, + visualizer: Visualizer | None = None, + optimizer_type: OptimizerType = OptimizerType.RUNTIME, + compile_data, + ) -> None: + super().__init__( + priority_executors=priority_executors, + produce_log=produce_log, + apply_bucketing_bw_trace=apply_bucketing_bw_trace, + log_file_name=log_file_name, + visualizer=visualizer, + optimizer_type=optimizer_type, + compile_data=compile_data + ) + + # Strat fusion + self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() + self.executor_placement_options: ExecutorPlacementOptions = ExecutorPlacementOptions() + from thunder.executors.nvfuserex_impl import linear, _linear_check from thunder.executors.nvfuserex_impl import matmul, _matmul_check - self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { "nvfuser": [ FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), @@ -206,6 +238,7 @@ def __init__( def _best_runtime_and_memory_candidates(self, candidates): from thunder.core.rematerialization import rematerialize_forward_and_backward from thunder.backend_optimizer.utils import benchmark_trace + from thunder.executors.cudagraphex import cudagraphex min_value_time: float = float("inf") min_value_mem: float = float("inf") @@ -213,57 +246,44 @@ def _best_runtime_and_memory_candidates(self, candidates): best_pair_memory: OutputCandidate pair: OutputCandidate for pair in candidates: - # No remat - pair_cost_time = 0 - pair_cost_mem = 0 - t, m, _ = benchmark_trace(pair.fw, iters=self.benchmark_iters) - log(f"Pair fw time no remat: {t}, mem: {m}", level=LogLevel.INFO) - pair_cost_time = pair_cost_time + t - pair_cost_mem = pair_cost_mem + m - t, m, _ = benchmark_trace(pair.bw, iters=self.benchmark_iters) - log(f"Pair bw time no remat: {t}, mem: {m}", level=LogLevel.INFO) - pair_cost_time = pair_cost_time + t - pair_cost_mem = pair_cost_mem + m - - if pair_cost_time < min_value_time: - best_pair_runtime = OutputCandidate(fw=pair.fw, bw=pair.bw, cost=pair_cost_time) - log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.INFO) - min_value_time = pair_cost_time - - if pair_cost_mem < min_value_mem: - best_pair_memory = OutputCandidate(fw=pair.fw, bw=pair.bw, cost=pair_cost_mem) - log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.INFO) - min_value_mem = pair_cost_mem - - # Apply remat and select best trace pair - pair_cost_time = 0 - pair_cost_mem = 0 - # In order to call rematerialize_forward_and_backward we need to set the cached compile options - # derived from the forward trace generation. At this stage all the infos are contained inside the pair object. if pair.compile_opt: remat_fw, remat_bw = wrap_fn_with_exeuctor_compile_option( pair.compile_opt, rematerialize_forward_and_backward, pair.fw, pair.bw ) else: remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) - t, m, _ = benchmark_trace(remat_fw, iters=self.benchmark_iters) - log(f"Pair fw time: {t}, mem: {m}", level=LogLevel.INFO) - pair_cost_time = pair_cost_time + t - pair_cost_mem = pair_cost_mem + m - t, m, _ = benchmark_trace(remat_bw, iters=self.benchmark_iters) - log(f"Pair bw time: {t}, mem: {m}", level=LogLevel.INFO) - pair_cost_time = pair_cost_time + t - pair_cost_mem = pair_cost_mem + m - - if pair_cost_time < min_value_time: - best_pair_runtime = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_time) - log(f"New best runtime pair:\n{best_pair_runtime}", level=LogLevel.INFO) - min_value_time = pair_cost_time - - if pair_cost_mem < min_value_mem: - best_pair_memory = OutputCandidate(fw=remat_fw, bw=remat_bw, cost=pair_cost_mem) - log(f"New best memory pair:\n{best_pair_memory}", level=LogLevel.INFO) - min_value_mem = pair_cost_mem + # Create pair final options by applying final optimizations: cudagraphs and rematerialization + pair_options: list[tuple[TraceCtx, TraceCtx]] = [ + (pair.fw, pair.bw), + (cudagraphex.fusion_pass(pair.fw), cudagraphex.fusion_pass(pair.bw)), + (remat_fw, remat_bw), + (cudagraphex.fusion_pass(remat_fw), cudagraphex.fusion_pass(remat_bw)), + ] + # Select the best options + for pair_option in pair_options: + fw = pair_option[0] + bw = pair_option[1] + + pair_cost_time = 0 + pair_cost_mem = 0 + t, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) + log(f"Pair fw time: {t}, mem: {m}", level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + t, m, _ = benchmark_trace(bw, iters=self.benchmark_iters) + log(f"Pair bw time: {t}, mem: {m}", level=LogLevel.INFO) + pair_cost_time = pair_cost_time + t + pair_cost_mem = pair_cost_mem + m + + if pair_cost_time < min_value_time: + best_pair_runtime = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_time) + log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.INFO) + min_value_time = pair_cost_time + + if pair_cost_mem < min_value_mem: + best_pair_memory = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_mem) + log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.INFO) + min_value_mem = pair_cost_mem return best_pair_runtime, best_pair_memory @@ -300,10 +320,10 @@ def fw_benchmark(): f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) # For forward trace we cache the best placement for both runtime and memory for the current Fusion executor (represented by label) - if compile_opt_time is not None: - print(f"Caching fw with compile options time: {compile_opt_time.fusion_tag}") - if compile_opt_mem is not None: - print(f"Caching fw with compile options mem: {compile_opt_mem.fusion_tag}") + # if compile_opt_time is not None: + # print(f"Caching fw with compile options time: {compile_opt_time.fusion_tag}") + # if compile_opt_mem is not None: + # print(f"Caching fw with compile options mem: {compile_opt_mem.fusion_tag}") for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): print(f'Caching fw candidate\n{t}\nwith option {o.fusion_tag if o else "None"}') @@ -456,7 +476,6 @@ def _search(ex: FusionExecutor, executor_compile_option: FusionCompileOptionsHel """ Fusable fn definition for nvFuser """ - # Each executor has a custom should fuse function, but the current impl need to access local executor object def _should_fuse_nvfuser(a: Node, b: Node): def _can_fuse_node(n: Node): @@ -473,7 +492,6 @@ def _can_fuse_node(n: Node): """ Fusable fn definition for torch.compile """ - def _should_fuse_torchcompile(a: Node, b: Node): def _can_fuse_node(n: Node): if len(n.group_bsyms) > 1: @@ -498,8 +516,14 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): else: raise AssertionError(f"Type not handled: {type(bsym_in.output)}") + merge_fn: Callable + match ex.name: + case 'nvfuser': + merge_fn = _should_fuse_nvfuser + case 'torchcompile': + merge_fn = _should_fuse_torchcompile bound_symbol_groups = fuse_bound_symbols( - self.trace, _should_fuse_nvfuser if ex.name == "nvfuser" else _should_fuse_torchcompile + self.trace, merge_fn ) log(f"Num of groups = {len(bound_symbol_groups)}", level=LogLevel.DEBUG) @@ -554,6 +578,7 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): candidate_best_mem = BenchmarkResult() # Search for best candidate, by default remat will be called to find the optimal choice # TODO: enable requests for no remat becnhmarks + # TODO: we should consider also FusionExecutor that can execute this single bsym in this beam search for i, candidate in enumerate(candidate_executors): # Match the current candidate to benchmark partial trace match_bsym_output(current_bsym, [dict_time_strat, dict_mem_strat], candidate) @@ -642,7 +667,7 @@ def measure_and_update_result(): n_missing_bsyms = len(group) - start_idx # TODO (matteochen): consider to add the iteration with no fusion regions - for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): + for i in range(0, n_missing_bsyms, 1 if self.trace_type == TraceType.BW else 1): # for i in range(0, n_missing_bsyms): # From top to bottom (this will include the whole region) # -> First iteration is the one with fusion region with single element @@ -666,13 +691,21 @@ def measure_and_update_result(): # From bottom to up (this will exclude the full region as being handled in the for cycle above) # -> First iteration is the one with len(fusion_region) - 1 # -> Last iteration gives no fusion regions - # for j in range(0, i+1, increment_factor): - # dict_time_strat[group[j].output.name] = get_default_executor(group[j]) - # for k in range(i+1, len(group), increment_factor): - # dict_time_strat[group[k].output.name] = ex + for j in range(start_idx, start_idx + i + 1, increment_factor): + match_bsym_output( + group[j], + [dict_time_strat, dict_mem_strat], + get_first_available_operator_executor( + bsym=group[j], + executors=self.executors, + empty_hash=self.empty_executor_hashable_placeholder, + ), + ) + for k in range(start_idx + i + 1, len(group), increment_factor): + match_bsym_output(group[k], [dict_time_strat, dict_mem_strat], ex) # Benchmark this placement - # measure_and_update_result() + measure_and_update_result() if best_placement_time is None or best_keys_time is None: raise AssertionError("Failed to get best time placement") @@ -772,7 +805,8 @@ def measure_and_update_result(): ex: FusionExecutor for ex in self.fusion_executors: if ex.name not in self.fusion_strat_helper.supported_executors: - raise AssertionError(f"Fusion operator not supported: {ex.name}") + # log(f"Fusion operator not supported: {ex.name}. Skipping it.") + continue log(f"Searching best placement for fusion executor = {ex.name}", level=LogLevel.INFO) @@ -787,7 +821,7 @@ def measure_and_update_result(): _search(ex) # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. - # TODO: Consider implementing patters based on the executor under investingation + # TODO: Consider implementing patterns based on the executor under investingation if ex_compile_opts: for opt in ex_compile_opts: # Search only if we have an instruction related to the compile option @@ -824,7 +858,7 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): log( f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO ) - # TODO (matteochen): support bw trace optimization even though with no fw traces cached + # TODO (matteochen): support bw trace optimization even though with no fw traces cached (computational trace?) case TraceType.BW: if not self.cached_fw_traces: raise AssertionError("Can not optimize backward traces before forward traces") @@ -894,6 +928,7 @@ def _optimize(): self.cached_fw_traces = [] _optimize() # We have multiple cached optimized fw traces, find the best backward + # TODO: make this prettier with a machine state for example case TraceType.BW: # Clear any previous results self.out_traces_candidates = [] @@ -950,8 +985,10 @@ def __init__( optimizer_algorithm: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER, compile_data, ) -> None: - self.optimizer = ( - FusionPlacer( + if optimizer_algorithm != OptimizationAlgorithm.BEST_FUSER: + raise AssertionError(f'Optimization {optimizer_algorithm} not implemented') + self.optimizer: PlacerBase = ( + FusionPlacer_BeamSearch( priority_executors=priority_executors, produce_log=produce_log, apply_bucketing_bw_trace=apply_bucketing_bw_trace, @@ -960,8 +997,6 @@ def __init__( optimizer_type=optimizer_type, compile_data=compile_data, ) - if optimizer_algorithm == OptimizationAlgorithm.BEST_FUSER - else None ) log("Executors:", level=LogLevel.INFO) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index e10d20d93e..a2d3987805 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Hashable, Sequence from typing import Any -from thunder.core.dtypes import dtype +from thunder.core.dtypes import dtype, to_torch_dtype from thunder.core.prims import PrimIDs from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify from thunder.core.symbol import BoundSymbol, Symbol @@ -281,10 +281,13 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f try: warm_up_iters = 50 torch.cuda.empty_cache() + torch.cuda.synchronize() # Warm up cycles for _ in range(warm_up_iters): fn(*args) # Benchmark + torch.cuda.empty_cache() + torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() for i in range(iters): torch.cuda.nvtx.range_push(f"{nvsight_fn_name}-iter{i}") @@ -304,13 +307,14 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl try: warm_up_iters = 50 out = None - torch.cuda.empty_cache() # Warm up cycles for _ in range(warm_up_iters): fn(*args) # Snapshot request if snapshot: + torch.cuda.empty_cache() + torch.cuda.synchronize() torch.cuda.memory._record_memory_history() fn(*args) torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") @@ -318,27 +322,23 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl # Benchmark stream = torch.cuda.current_stream() max_allocated_bytes = 0 - torch.cuda.synchronize() - warm_up_iters = 10 start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - for i in range(iters + warm_up_iters): - if i >= warm_up_iters: - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i-warm_up_iters].record(stream) - fn(*args) - end_events[i-warm_up_iters].record(stream) - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) - ) - else: - fn(*args) + torch.cuda.synchronize() + for i in range(iters): + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + fn(*args) + end_events[i].record(stream) + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) + ) torch.cuda.synchronize() times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - print(f"times: {times}") + # print(f"times: {times}") tot_time = sum(times) / iters return tot_time, max_allocated_bytes, out except Exception as e: @@ -420,17 +420,19 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: shape = arg.shape device = arg.device requires_grad = arg.requires_grad - if dtype is not None and is_float_dtype(dtype): - torch_dtype = thunder_to_torch_float_dtype(dtype, dtype.bytes) + + torch_dtype = to_torch_dtype(dtype) + if torch_dtype is None: + raise AssertionError(f'Unrecognized thunder dtype: {dtype}') + if is_float_dtype(dtype): tensor: torch.Tensor = torch.randn( shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad ) - elif dtype is not None and is_signedinteger_dtype(dtype): - torch_dtype = thunder_to_torch_int_dtype(dtype.bytes) + elif is_signedinteger_dtype(dtype): tensor: torch.Tensor = torch.randint( 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad ) - elif dtype is not None and is_boolean_dtype(dtype): + elif is_boolean_dtype(dtype): # TODO (matteochen): maybe random? tensor: torch.Tensor = torch.zeros( *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 54d0806e11..b2b5843cd2 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -3,35 +3,22 @@ warm_up_iters = 50 -def torch_fw_bw_benchmark_nvsight(models: list, torch_module: torch.nn.Module | None, labels: list, inputs: list, iters: int, int_input_tensor: bool = False) -> None: +def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int) -> None: for m, input, label in zip(models, inputs, labels): # Warm up - for _ in range(10): + for _ in range(warm_up_iters): y = m(input) - # Not supported by autograd - if int_input_tensor: - torch.autograd.grad(y.sum(), torch_module.parameters()) - else: - torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + y.sum().backward() + torch.cuda.empty_cache() + torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() - for i in range(iters): + for _ in range(iters): torch.cuda.empty_cache() - torch.cuda.nvtx.range_push(f'iteration_nvsight-{label}') - torch.cuda.nvtx.range_push("fw_nvsight") + torch.cuda.nvtx.range_push(f"{label}: fw-bw") y = m(input) - torch.cuda.nvtx.range_pop() - # Not supported by autograd - if int_input_tensor: - torch.cuda.nvtx.range_push("bw_nvsight") - torch.autograd.grad(y.sum(), torch_module.parameters()) - torch.cuda.nvtx.range_pop() - else: - torch.cuda.nvtx.range_push("bw_nvsight") - torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) - torch.cuda.nvtx.range_pop() - + y.sum().backward() torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() diff --git a/thunder/common.py b/thunder/common.py index e0fc05439b..95414015bd 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -195,7 +195,7 @@ def __init__( using_jit: bool = False, only_execute_prims: bool = False, disable_preprocessing: bool = False, - use_cudagraphs: bool = False, + use_cudagraphs: bool | None = None, disable_torch_autograd_support: bool = False, use_rematerialization: bool = False, debug_log: None | StringIO = None, diff --git a/thunder/executors/cudagraphex.py b/thunder/executors/cudagraphex.py index 963acdb497..95182315b8 100644 --- a/thunder/executors/cudagraphex.py +++ b/thunder/executors/cudagraphex.py @@ -94,7 +94,14 @@ def __call__(self, *args): for static_input, arg in utils.safe_zip(static_inputs, args): if id(static_input) != id(arg) and isinstance(static_input, torch.Tensor) and isinstance(arg, torch.Tensor): - static_input.copy_(arg) + try: + static_input.copy_(arg) + except RuntimeError as e: + if ( + "unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation." + in str(e) + ): + static_input.clone().copy_(arg) graph.replay() diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 2b212f4f3a..7f66f3758d 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -111,8 +111,8 @@ def update_bw_from_forward_optimization(*, fw: TraceCtx, bw:TraceCtx) -> TraceCt # Some of the optimization passes change proxies in the trace and # any change in the forward trace must be reflected in the backward # trace. - original_bw_saved_tensors_for_backward = bw_trace.args[0][0] - new_fw_saved_tensors_for_backward = get_saved_for_backward_tensors(fw_extrace) + original_bw_saved_tensors_for_backward = bw.args[0][0] + new_fw_saved_tensors_for_backward = get_saved_for_backward_tensors(fw) swap_map = { variableify(x): y for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) @@ -352,6 +352,7 @@ def split(): visualizer.set_fw_final_trace(fw_extrace) visualizer.set_bw_final_trace(bw_extrace) + # TODO: implement new visualizer # visualizer.produce() if autotune_type is None: @@ -484,27 +485,17 @@ def split(): compile_stats.last_backward_traces += bw_traces return fw_extrace, bw_extrace - except AssertionError as e: - print(f'Exception occured: {e}') + except Exception as e: + import traceback + print(f'Exception occured:\n{e}\nTraceback:') + traceback.print_exc() # Restore before calling split compile_data.executors_list = list(cached_executor_list) log( - f"================================================================================ Before Autotune Tuning: exception occured, not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) - primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() - if compile_stats is not None: - compile_stats.last_traces.append(primal_trace) - compile_stats.last_traces += fw_traces - compile_stats.last_backward_traces += bw_traces - - return fw_extrace, bw_extrace - except RuntimeError as e: - print(f'Exception occured: {e}') - # Restore before calling split - compile_data.executors_list = list(cached_executor_list) - - log( - f"================================================================================ Before Autotune Tuning: exception occured, not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) + f"================================================================================ Before Autotune Tuning: exception occured, executors:\n{executors_candidates} will not be autotuned", + level=LogLevel.INFO, + ) primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() if compile_stats is not None: compile_stats.last_traces.append(primal_trace) From 31c781e6251bbdbb2d5177528946e1e0af2abe7d Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 13 Aug 2024 17:49:18 +0300 Subject: [PATCH 040/171] Restore nanogpt config --- examples/dev/nanogpt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index 361d795d52..4cb828345a 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -11,7 +11,7 @@ def run(target: str = 'runtime'): raise AssertionError(f'Target {target} not supported. Only runtime and memory available') # ----------------------------------------------------------------------------- batch_size = 12 - block_size = 512 + block_size = 1024 bias = False real_data = False seed = 1337 @@ -42,7 +42,7 @@ def run(target: str = 'runtime'): # model init gptconf = GPTConfig( block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 1, n_head = 6, n_embd = 768, # size of the model + n_layer = 4, n_head = 12, n_embd = 768, # size of the model dropout = 0, # for determinism bias = bias, ) From 03188f8a6ec2aacca2642cef63e5393ecc6b26d5 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 10:20:23 +0300 Subject: [PATCH 041/171] Added nvsight to bench --- examples/dev/litGPT.py | 4 +++- examples/dev/nanogpt.py | 48 ++++++++++++++++++++++++++--------------- 2 files changed, 34 insertions(+), 18 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 3cb68b9993..8d3aad2b7d 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -1,5 +1,5 @@ from litgpt import GPT -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_total_benchmark +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark from thunder.tests.litgpt_model import Config import thunder import torch @@ -44,6 +44,8 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: print(f'\n\nResults torch total benchmark ({iters} iters):') torch_total_benchmark(callables, labels, inputs, iters) + torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) + print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') print('###############################################################################') diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index 4cb828345a..56795829fd 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -42,16 +42,16 @@ def run(target: str = 'runtime'): # model init gptconf = GPTConfig( block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 4, n_head = 12, n_embd = 768, # size of the model + n_layer = 12, n_head = 12, n_embd = 768, # size of the model dropout = 0, # for determinism bias = bias, ) model = GPT(gptconf) model.to(device) - jmodel_def = thunder.jit(model, use_cudagraphs=True) + jmodel_def = thunder.jit(model) # Currently sdpa does not work? - jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python'], use_cudagraphs=True) + jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) if compile: print("Compiling model...") @@ -90,30 +90,28 @@ def run(target: str = 'runtime'): prof.step() # notify the profiler at end of each step else: + # simple benchmarking def measure(m, label): - # simple benchmarking + iters = 100 torch.cuda.synchronize() - X, Y = get_batch('train') for i in range(warm_up_iters): + X, Y = get_batch('train') with ctx: _, loss = m(X, Y) - X, Y = get_batch('train') loss.backward() - iters = 100 start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] stream = torch.cuda.current_stream() torch.cuda.synchronize() - X, Y = get_batch('train') for i in range(iters): torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) + X, Y = get_batch('train') start_events[i].record(stream) with ctx: _, loss = m(X, Y) - X, Y = get_batch('train') loss.backward() end_events[i].record(stream) @@ -123,6 +121,28 @@ def measure(m, label): print('\n\nResults torch benchmark:') print(f'{label} tot time: {tot_time} ms') + def measure_nvsight(m, label): + # Warm up + for _ in range(warm_up_iters): + X, Y = get_batch('train') + with ctx: + _, loss = m(X, Y) + loss.backward() + + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.cudart().cudaProfilerStart() + # Perform less iterations + for _ in range(20): + torch.cuda.empty_cache() + X, Y = get_batch('train') + torch.cuda.nvtx.range_push(f"{label}: fw-bw") + with ctx: + _, loss = m(X, Y) + loss.sum().backward() + torch.cuda.nvtx.range_pop() + torch.cuda.cudart().cudaProfilerStop() + measure(jmodel_auto, 'auto') measure(jmodel_def, 'def') @@ -138,14 +158,8 @@ def measure(m, label): labels.reverse() thunder_fw_bw_benchmark(traces, labels, 100) - # X, Y = get_batch('train') - # out_eager = model(X, Y) - # out_def = jmodel_def(X, Y) - # out_auto = jmodel_auto(X, Y) - # for a, b in zip(out_eager, out_def): - # print('deviation def:', (a - b).abs().max().item()) - # for a, b in zip(out_eager, out_auto): - # print('deviation auto:', (a - b).abs().max().item()) + measure_nvsight(jmodel_def, 'def') + measure_nvsight(jmodel_auto, 'auto') traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] for t in traces: From 74c57601def2b58848cf693618049050a3a28a9e Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 10:21:25 +0300 Subject: [PATCH 042/171] Restored old value --- thunder/backend_optimizer/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 00fbb0f16a..e1c51b8fe0 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -667,7 +667,7 @@ def measure_and_update_result(): n_missing_bsyms = len(group) - start_idx # TODO (matteochen): consider to add the iteration with no fusion regions - for i in range(0, n_missing_bsyms, 1 if self.trace_type == TraceType.BW else 1): + for i in range(0, n_missing_bsyms, n_missing_bsyms-1 if self.trace_type == TraceType.BW else 1): # for i in range(0, n_missing_bsyms): # From top to bottom (this will include the whole region) # -> First iteration is the one with fusion region with single element From 51f45922da7af49bc01b2b42a86704b56afc5193 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 11:52:48 +0300 Subject: [PATCH 043/171] Updated nsight iter / updated test models --- examples/dev/MLP.py | 21 +++++++++++++++------ examples/dev/nvfuser_optimizations.py | 10 +++++++--- thunder/backend_optimizer/utils.py | 1 + thunder/benchmarks/utils.py | 4 ++-- 4 files changed, 25 insertions(+), 11 deletions(-) diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py index ef0bb2a7d6..9ae0fcfd13 100644 --- a/examples/dev/MLP.py +++ b/examples/dev/MLP.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark class ModelConfig: def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): @@ -28,7 +28,7 @@ def forward(self, x): with torch.device('cuda'): embeddings = 3072 - config = ModelConfig(n_embd=embeddings) + config = ModelConfig(n_embd=embeddings, dropout=0.0, bias=False) dtype = torch.float32 x = torch.randn(16, 1024, embeddings, requires_grad=True) @@ -36,23 +36,32 @@ def forward(self, x): jmodel_def = thunder.jit(model) # This model fails under some circumstances after passed the placed traced under the rematelizer - jmodel_auto = thunder.jit(model, autotune_type='memory', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'torch', 'python'], use_cudagraphs=False) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + iters = 100 callables = [jmodel_auto, jmodel_def] labels = ['auto', 'def'] inputs = [x, x] print('Results with torch fw bw benchmark:') - torch_fw_bw_benchmark(callables, labels, inputs, 5) + torch_fw_bw_benchmark(callables, labels, inputs, iters) + torch_total_benchmark(callables, labels, inputs, iters) + torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) print('Results with thunder benchmark:') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] traces.reverse() labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] labels.reverse() - thunder_fw_bw_benchmark(traces, labels, 5, nvsight = False) + thunder_fw_bw_benchmark(traces, labels, iters, nvsight = False) + thunder_fw_bw_benchmark(traces, labels, iters, nvsight = True) # for t in traces: # print(t) diff --git a/examples/dev/nvfuser_optimizations.py b/examples/dev/nvfuser_optimizations.py index 95abc755ba..487b1b24aa 100644 --- a/examples/dev/nvfuser_optimizations.py +++ b/examples/dev/nvfuser_optimizations.py @@ -1,6 +1,6 @@ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: @@ -32,6 +32,7 @@ def forward(self, x: torch.Tensor): y = jmodel_def(x) y = jmodel_auto(x) + iters = 100 print('Results thunder benchmark:') traces = [ thunder.last_traces(jmodel_def)[-1], @@ -40,13 +41,16 @@ def forward(self, x: torch.Tensor): thunder.last_backward_traces(jmodel_auto)[-1], ] labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, 50) + thunder_fw_bw_benchmark(traces, labels, iters) + thunder_fw_bw_benchmark(traces, labels, iters, nvsight=True) callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] inputs = [x, x] print('Results torch benchmark:') - torch_fw_bw_benchmark(callables, labels, inputs, 50) + torch_fw_bw_benchmark(callables, labels, inputs, iters) + torch_total_benchmark(callables, labels, inputs, iters) + torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) for t in traces: print(f'{t}\n#########################################') diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index a2d3987805..a9b349156e 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -290,6 +290,7 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() for i in range(iters): + torch.cuda.empty_cache() torch.cuda.nvtx.range_push(f"{nvsight_fn_name}-iter{i}") fn(*args) torch.cuda.nvtx.range_pop() diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index b2b5843cd2..598eab28c8 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -14,9 +14,9 @@ def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iter torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() - for _ in range(iters): + for i in range(iters): torch.cuda.empty_cache() - torch.cuda.nvtx.range_push(f"{label}: fw-bw") + torch.cuda.nvtx.range_push(f"{label}: fw-bw iter {i}") y = m(input) y.sum().backward() torch.cuda.nvtx.range_pop() From 0a101b1d2feed6ffe18566bb0f14e10b18a931b0 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 11:53:10 +0300 Subject: [PATCH 044/171] Using cuda graphs only is not disabled --- thunder/backend_optimizer/optimizer.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index e1c51b8fe0..1dd0491bd3 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -238,7 +238,6 @@ def __init__( def _best_runtime_and_memory_candidates(self, candidates): from thunder.core.rematerialization import rematerialize_forward_and_backward from thunder.backend_optimizer.utils import benchmark_trace - from thunder.executors.cudagraphex import cudagraphex min_value_time: float = float("inf") min_value_mem: float = float("inf") @@ -255,10 +254,17 @@ def _best_runtime_and_memory_candidates(self, candidates): # Create pair final options by applying final optimizations: cudagraphs and rematerialization pair_options: list[tuple[TraceCtx, TraceCtx]] = [ (pair.fw, pair.bw), - (cudagraphex.fusion_pass(pair.fw), cudagraphex.fusion_pass(pair.bw)), (remat_fw, remat_bw), - (cudagraphex.fusion_pass(remat_fw), cudagraphex.fusion_pass(remat_bw)), ] + # We want to verify that it is not set to false + if self.compile_data.use_cudagraphs is None or self.compile_data.use_cudagraphs == True: + from thunder.executors.cudagraphex import cudagraphex + pair_options.extend( + [ + (cudagraphex.fusion_pass(pair.fw), cudagraphex.fusion_pass(pair.bw)), + (cudagraphex.fusion_pass(remat_fw), cudagraphex.fusion_pass(remat_bw)), + ] + ) # Select the best options for pair_option in pair_options: fw = pair_option[0] From 940a704400c029dda19de6702948424a53aae600 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 16:18:34 +0300 Subject: [PATCH 045/171] Updated bench fn --- thunder/backend_optimizer/utils.py | 2 +- thunder/benchmarks/utils.py | 27 +++++++++++++++++---------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index a9b349156e..eed92df03f 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -291,7 +291,7 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f torch.cuda.cudart().cudaProfilerStart() for i in range(iters): torch.cuda.empty_cache() - torch.cuda.nvtx.range_push(f"{nvsight_fn_name}-iter{i}") + torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nvsight_fn_name}, iter{i}") fn(*args) torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 598eab28c8..8b0109ecd9 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -16,9 +16,14 @@ def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iter torch.cuda.cudart().cudaProfilerStart() for i in range(iters): torch.cuda.empty_cache() - torch.cuda.nvtx.range_push(f"{label}: fw-bw iter {i}") + torch.cuda.nvtx.range_push(f"torch training {label}, iter {i}") + torch.cuda.nvtx.range_push('forward') y = m(input) - y.sum().backward() + torch.cuda.nvtx.range_pop() + loss = y.sum() + torch.cuda.nvtx.range_push('backward') + loss.backward() + torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() @@ -30,15 +35,15 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) y = m(input) y.sum().backward() - torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] stream = torch.cuda.current_stream() max_allocated_bytes = 0 + torch.cuda.synchronize() for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) start_events[i].record(stream) y = m(input) @@ -55,19 +60,20 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) print(f'{label} tot fw time: {tot_time} ms') print(f'{label} max fw allocated memory: {max_allocated_bytes / (2**30)} GB') - torch.cuda.synchronize() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] stream = torch.cuda.current_stream() max_allocated_bytes = 0 + torch.cuda.synchronize() for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) y = m(input) + loss = y.sum() + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) start_events[i].record(stream) - y.sum().backward() + loss.backward() end_events[i].record(stream) max_allocated_bytes = max( @@ -91,17 +97,18 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - torch.cuda.synchronize() stream = torch.cuda.current_stream() max_allocated_bytes = 0 + torch.cuda.synchronize() for i in range(iters): - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) start_events[i].record(stream) y = m(input) - y.sum().backward() + loss = y.sum() + loss.backward() end_events[i].record(stream) max_allocated_bytes = max( From 9ef9d0f0c5cba3769792d4583af8420f208f62e5 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 16:35:51 +0300 Subject: [PATCH 046/171] Fixed nanogpt test --- examples/dev/nanogpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index 56795829fd..9de4111e08 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -139,7 +139,7 @@ def measure_nvsight(m, label): torch.cuda.nvtx.range_push(f"{label}: fw-bw") with ctx: _, loss = m(X, Y) - loss.sum().backward() + loss.backward() torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() From 20992d6f4272fc27a9362b0a4d3d505024c64647 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 17:41:47 +0300 Subject: [PATCH 047/171] Updated log --- thunder/executors/cudagraphex.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/thunder/executors/cudagraphex.py b/thunder/executors/cudagraphex.py index 95182315b8..77b740c485 100644 --- a/thunder/executors/cudagraphex.py +++ b/thunder/executors/cudagraphex.py @@ -94,15 +94,7 @@ def __call__(self, *args): for static_input, arg in utils.safe_zip(static_inputs, args): if id(static_input) != id(arg) and isinstance(static_input, torch.Tensor) and isinstance(arg, torch.Tensor): - try: - static_input.copy_(arg) - except RuntimeError as e: - if ( - "unsupported operation: more than one element of the written-to tensor refers to a single memory location. Please clone() the tensor before performing the operation." - in str(e) - ): - static_input.clone().copy_(arg) - + static_input.copy_(arg) graph.replay() return static_outputs From 0ce7921dd0f4f4e637366e2ecd8d7b7478898e4b Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 17:43:07 +0300 Subject: [PATCH 048/171] Updated log --- thunder/backend_optimizer/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 1dd0491bd3..0c54e69b58 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -332,7 +332,7 @@ def fw_benchmark(): # print(f"Caching fw with compile options mem: {compile_opt_mem.fusion_tag}") for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): - print(f'Caching fw candidate\n{t}\nwith option {o.fusion_tag if o else "None"}') + log(f'Caching fw candidate [compile option: {o.fusion_tag if o else "None"}]\n{t}') self.cached_fw_traces.append( TraceCandidate(trace=t, compile_opt=o, label=label + '_enabled_' + o.fusion_tag if o is not None else label) ) From dcb6326315b17c601107e884b0b6ad23fae0c160 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 17:45:56 +0300 Subject: [PATCH 049/171] Added gitignore --- examples/dev/.gitignore | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 examples/dev/.gitignore diff --git a/examples/dev/.gitignore b/examples/dev/.gitignore new file mode 100644 index 0000000000..4ec10c2e86 --- /dev/null +++ b/examples/dev/.gitignore @@ -0,0 +1,4 @@ +*.log +*.txt +*.pickle +*.nsys-rep From 9f6ccb19e620a1bafe7736c33e999591956f1772 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 17:46:40 +0300 Subject: [PATCH 050/171] Added fa3 to autotuner --- thunder/executors/torch_autograd.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 7f66f3758d..4b6ad4fb2d 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -364,10 +364,11 @@ def split(): # TODO (matteochen): integrate Transofrmer Engine from thunder.executors.sdpaex import sdpa_ex from thunder.executors.cudnnex import cudnn_ex + from thunder.executors.fa3ex import fa3_ex from thunder.executors.transformer_engineex import transformer_engine_ex executors_candidates: dict[str, list] = { - 'scaled_dot_product_attention': [sdpa_ex.name, cudnn_ex.name], + 'scaled_dot_product_attention': [sdpa_ex.name, cudnn_ex.name, fa3_ex.name], 'linear_layer': [transformer_engine_ex.name] } @@ -405,6 +406,8 @@ def split(): f"================================================================================ Before Autotune Tuning: Skipping optimization for {ex_type} as not requested.", level=LogLevel.INFO) + # TODO: do we want to give a chance to the configuration with no executor? + # e.g. assigning scaled_dot_product_attention to torch and not to sdpa / cudnn for e in to_benchmark: compile_data.executors_list = [ex for ex in cached_executor_list if ex not in to_benchmark] # Make it with most priority From 57e6163fdce99382923d8bfb26bc7d7af4a3f19c Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 17:47:06 +0300 Subject: [PATCH 051/171] Fixed log --- thunder/backend_optimizer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index eed92df03f..6a269396f1 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -518,7 +518,7 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: except Exception as e: print(f"Compiled trace execution still failed:\n{e}") else: - print(f"Unknown exception occured:\n{e}") + print(f"Unhandled exception occured:\n{e}") finally: reset_tracectx(trace_tok) From cc96535b55d6c77a29b73c83ade0e4a94624ef08 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 14 Aug 2024 18:18:20 +0300 Subject: [PATCH 052/171] Disabled te for now / updated exceptions log --- examples/dev/LLaMAMLP.py | 22 ++++++++++++------ examples/dev/nanogpt.py | 2 +- thunder/backend_optimizer/utils.py | 36 +++++++++++++---------------- thunder/executors/torch_autograd.py | 7 +++--- 4 files changed, 36 insertions(+), 31 deletions(-) diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 074c6e6a58..dd9b3ce262 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -1,6 +1,6 @@ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark class LLaMAMLP(torch.nn.Module): def __init__(self, n_embd, intermediate_size) -> None: @@ -15,7 +15,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.proj(x) with torch.device('cuda'): - # See changes from mult = 1 to mult = 4 mult = 1 a = 4096 * mult b = 11008 * mult @@ -24,21 +23,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = LLaMAMLP(a, b) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type='runtime', use_cudagraphs=False) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + iters = 100 print('Results with thunder benchmark:') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, 50, nvsight = False) + thunder_fw_bw_benchmark(traces, labels, iters, nvsight = False) callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] inputs = [x, x] - print('Results with torch fw bw benchmark:') - torch_fw_bw_benchmark(callables, labels, inputs, 50) + print('\nResults with torch fw bw benchmark:') + torch_fw_bw_benchmark(callables, labels, inputs, iters) + print('\nResults with torch total benchmark:') + torch_total_benchmark(callables, labels, inputs, iters) + torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) for t in traces: print(f'{t}\n#####################################') diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index 9de4111e08..eb519adec5 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -51,7 +51,7 @@ def run(target: str = 'runtime'): jmodel_def = thunder.jit(model) # Currently sdpa does not work? - jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'sdpa', 'torch', 'python'], use_cudagraphs=False) if compile: print("Compiling model...") diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 6a269396f1..5bc3baceae 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -14,24 +14,21 @@ def sequence_hash(s: Sequence) -> str: - name = "" - for e in s: - if ( - isinstance(e, CollectionProxy) - or isinstance(e, TensorProxy) - or isinstance(e, IntegerProxy) - or isinstance(e, FloatProxy) - ): - name += e.name - # TODO (matteochen): investigate if this is suitable - elif isinstance(e, int): - name += f"int{e}" - elif e is None: - name += "None" - else: - raise AssertionError(f"What? Maybe nested Sequence. type = {type(e)}") - return name + def rec(s) -> str: + name = "[" + for e in s: + if e is None: + name += "None" + elif hasattr(e, "name"): + name += e.name + elif isinstance(e, Sequence): + name += rec(e) + else: + raise AssertionError(f"Unsupported type = {type(e)}") + name += ']' + return name + return rec(s) def can_executor_execute(ex: Executor, bsym: BoundSymbol) -> bool: try: @@ -502,9 +499,10 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: else: t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) except Exception as e: + import traceback + traceback.print_exc() # https://github.com/Lightning-AI/lightning-thunder/issues/664 # Seems that this patch never work ... - print(f"Exception:\n{e}") if ( "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) and not nvsight @@ -517,8 +515,6 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: t, m, answer = compute_time_cost_ms(torch_compiled, executable_str, iters, *input_args) except Exception as e: print(f"Compiled trace execution still failed:\n{e}") - else: - print(f"Unhandled exception occured:\n{e}") finally: reset_tracectx(trace_tok) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 4b6ad4fb2d..6f201fbd60 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -368,8 +368,8 @@ def split(): from thunder.executors.transformer_engineex import transformer_engine_ex executors_candidates: dict[str, list] = { + # 'linear_layer': [transformer_engine_ex.name], 'scaled_dot_product_attention': [sdpa_ex.name, cudnn_ex.name, fa3_ex.name], - 'linear_layer': [transformer_engine_ex.name] } # TODO (matteochen): use BackendOptimizer tracing @@ -488,15 +488,16 @@ def split(): compile_stats.last_backward_traces += bw_traces return fw_extrace, bw_extrace + # TODO: catch individual failures except Exception as e: import traceback - print(f'Exception occured:\n{e}\nTraceback:') + log(f'Exception occured when tuning {executors_candidates}:\n{e}\nTraceback:') traceback.print_exc() # Restore before calling split compile_data.executors_list = list(cached_executor_list) log( - f"================================================================================ Before Autotune Tuning: exception occured, executors:\n{executors_candidates} will not be autotuned", + f"================================================================================ Before Autotune Tuning: exception occured, executors:\n{executors_candidates} will not be autotuned (priority list policy will be used)", level=LogLevel.INFO, ) primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() From 89f8d3fa66388ab852d7bde08fbfafe013494ebd Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 19 Aug 2024 09:15:55 +0300 Subject: [PATCH 053/171] Refactored and fixed autotuning that requires fw and bw split handling (#14) --- thunder/__init__.py | 2 + thunder/backend_optimizer/optimizer.py | 1 + thunder/backend_optimizer/utils.py | 5 +- thunder/executors/torch_autograd.py | 164 ++++++++++++++----------- 4 files changed, 99 insertions(+), 73 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 61d6e0adeb..d5f367ac71 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -334,6 +334,8 @@ def jit( # Otherwise the user restricted choice will be used if not executors: executors = get_all_executors() + # Remove python and cudagraph + executors = [ex for ex in executors if ex.name != 'python' and ex.name != 'cudagraphex'] # Resolve names of executors executors = resolve_executors(executors) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 0c54e69b58..fa2883a468 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1023,3 +1023,4 @@ def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: return self.optimizer.get_optimal_fw_bw_traces() + diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 5bc3baceae..a3713a67ff 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -23,6 +23,8 @@ def rec(s) -> str: name += e.name elif isinstance(e, Sequence): name += rec(e) + elif isinstance(e, int): + name += 'int' + str(e) else: raise AssertionError(f"Unsupported type = {type(e)}") name += ']' @@ -500,7 +502,8 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) except Exception as e: import traceback - traceback.print_exc() + ex_str = traceback.format_exc() + print(ex_str) # https://github.com/Lightning-AI/lightning-thunder/issues/664 # Seems that this patch never work ... if ( diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 6f201fbd60..4abcbfbf11 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -3,6 +3,7 @@ import torch +from thunder.backend_optimizer.utils import operation_in_trace import thunder.core.utils as utils from thunder.core.prims import PrimIDs from thunder.core.proxies import TensorProxy, variableify @@ -10,7 +11,7 @@ from thunder.core.symbol import BoundSymbol from thunder.core.trace import TraceCtx, from_trace, set_tracectx, reset_tracectx from thunder.core.transform_common import replace_redundant_inputs -from thunder.extend import OperatorExecutor +from thunder.extend import OperatorExecutor, Executor from thunder.core.vjp_utils import get_saved_for_backward_tensors if TYPE_CHECKING: @@ -366,9 +367,10 @@ def split(): from thunder.executors.cudnnex import cudnn_ex from thunder.executors.fa3ex import fa3_ex from thunder.executors.transformer_engineex import transformer_engine_ex + from thunder.executors.torchex import ex as torchex executors_candidates: dict[str, list] = { - # 'linear_layer': [transformer_engine_ex.name], + 'linear': [transformer_engine_ex.name], 'scaled_dot_product_attention': [sdpa_ex.name, cudnn_ex.name, fa3_ex.name], } @@ -378,46 +380,77 @@ def split(): # as the autotuner will receive already split fw and bw traces if autotune_type is not None: cached_executor_list = list(compile_data.executors_list) - try: - # disable this part for now - # raise RuntimeError('Disabled') - is_tuned = False + cached_executor_list_copy = list(compile_data.executors_list) + placed = set() - # We are interested to save the best_*s at the last iteration over the executors_candidates dict as the last - # out *_extrace from calling split will contain all the best executors computed incrementally - # i.e: best_* will track the best placemet for iteration (executors_candidates iteration) i plus every iteration from [0, i-1] - best_cost: float = float('inf') - best_fw_extrace: TraceCtx | None = None - best_bw_extrace: TraceCtx | None = None - best_fw_traces: list[TraceCtx] = [] - best_bw_traces: list[TraceCtx] = [] - best_primal_trace: TraceCtx | None = None - best_executor: OperatorExecutor | None = None + def find_torchex_index(): + index = 0 + for i, e in enumerate(cached_executor_list): + if e in placed: + index = i+1 + return min(len(cached_executor_list), index) + try: + is_tuned = False + benchmark_iters: int = 20 for i, (ex_type, ex_list) in enumerate(executors_candidates.items()): + # We need to reference some additional tructtures other than the best fw and bw traces as we have to update compile data only after we got the optimal choice + best_cost: float = float('inf') + best_fw_extrace: TraceCtx | None = None + best_bw_extrace: TraceCtx | None = None + best_fw_traces: list[TraceCtx] = [] + best_bw_traces: list[TraceCtx] = [] + best_primal_trace: TraceCtx | None = None + best_executor: Executor | None = None log( f"================================================================================ Before Autotune Tuning: Optimizing {ex_type}", level=LogLevel.INFO) - # Search in the requested executor list if one or more than one options for a know multiple executable region is available - to_benchmark = [ex for ex in cached_executor_list if ex.name in ex_list] - - if not to_benchmark: + # Filter out the executors based on the executors list, maybe not all the options have to be used + # Torch executor (default kernel) will be given a chance always + to_benchmark: list[Executor] = [ex for ex in cached_executor_list if ex.name in ex_list] + # Add torchexecutor if not present + if torchex not in to_benchmark: + to_benchmark.append(torchex) + # Verify that op is present in the trace + op_in_trace: bool = operation_in_trace(trace=computation_trc, op=ex_type) + + if (not to_benchmark and op_in_trace) or not op_in_trace: log( - f"================================================================================ Before Autotune Tuning: Skipping optimization for {ex_type} as not requested.", - level=LogLevel.INFO) + f"================================================================================ Before Autotune Tuning: Skipping optimization for {ex_type} as not requested or not present in computation_trc.", + level=LogLevel.INFO, + ) + continue - # TODO: do we want to give a chance to the configuration with no executor? - # e.g. assigning scaled_dot_product_attention to torch and not to sdpa / cudnn + log( + f"================================================================================ Before Autotune Tuning: Executors to bench for {ex_type}: {to_benchmark}", + level=LogLevel.INFO, + ) for e in to_benchmark: + # Create the executor list putting the executor under analysis at the head of queue + # 1. Add all executors except the ones under benchmark compile_data.executors_list = [ex for ex in cached_executor_list if ex not in to_benchmark] - # Make it with most priority - compile_data.executors_list.insert(0, e) + # 2. Make the current one with most priority to be picked up by + if e in compile_data.executors_list: + compile_data.executors_list.insert(0, compile_data.executors_list.pop(compile_data.executors_list.index(e))) + else: + compile_data.executors_list.insert(0, e) + # TODO: write why we have to place torchex as near as possible to the start of the list + if torchex not in compile_data.executors_list: + torchex_index = max(1, find_torchex_index()) + compile_data.executors_list.insert(torchex_index, torchex) log( f"================================================================================ Before Autotune Tuning: Testing compile data executors: {compile_data.executors_list}", level=LogLevel.INFO) + try: + primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() + except Exception as exception: + import traceback + ex_str = traceback.format_exc() + log( + f"================================================================================ Before Autotune Tuning: Failed to place {e.name}: {exception}\n{ex_str}") + continue - primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() - time_fw, mem_fw, _ = benchmark_trace(fw_extrace, iters=10, apply_del_last_used=False) - time_bw, mem_bw, _ = benchmark_trace(bw_extrace, iters=10, apply_del_last_used=False) + time_fw, mem_fw, _ = benchmark_trace(fw_extrace, iters=benchmark_iters, apply_del_last_used=False) + time_bw, mem_bw, _ = benchmark_trace(bw_extrace, iters=benchmark_iters, apply_del_last_used=False) tot_time = time_fw + time_bw tot_mem = mem_fw + mem_bw log( @@ -437,64 +470,51 @@ def split(): best_bw_traces = bw_traces best_primal_trace = primal_trace best_executor = e + print(f'Best executor end iteration: {best_executor}') - # c, m , _ = benchmark_trace(best_fw_extrace, iters=10, apply_del_last_used=False) - # print(f'inside update {c}') - # c, m , _ = benchmark_trace(best_bw_extrace, iters=10, apply_del_last_used=False) - # print(f'inside update {c}') - - c, m , _ = benchmark_trace(best_fw_extrace, iters=10, apply_del_last_used=False) + c, m , _ = benchmark_trace(best_fw_extrace, iters=benchmark_iters, apply_del_last_used=False) log( f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best fw_extrace (time = {c}, mem = {m}):\n{best_fw_extrace}", level=LogLevel.INFO) - c, m , _ = benchmark_trace(best_bw_extrace, iters=10, apply_del_last_used=False) + c, m , _ = benchmark_trace(best_bw_extrace, iters=benchmark_iters, apply_del_last_used=False) log( f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best bw_extrace (time = {c}, mem = {m}):\n{best_bw_extrace}", level=LogLevel.INFO) # Update the executor list with the winner executor for the current ex_type cached_executor_list = [ex for ex in cached_executor_list if ex not in to_benchmark] - # We have a solution, we don't have it if not requested from the executor list - if best_executor is not None: - cached_executor_list.insert(0, best_executor) - best_executor = None + if best_executor is None: + log( + f"================================================================================ Before Autotune Tuning: Could not find best executor for {ex_type}. Assigning torchex by default", level=LogLevel.INFO) + best_executor = torchex + cached_executor_list.insert(0, best_executor) + placed.add(best_executor) + log( + f"================================================================================ Before Autotune Tuning: Best executor for {ex_type}: {best_executor.name}", level=LogLevel.INFO) log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type}, new executor list: {cached_executor_list}", level=LogLevel.INFO) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type}, updated executor list: {cached_executor_list}", level=LogLevel.INFO) # Update the compile stats on the last iter if i == len(executors_candidates)-1: # Check that we have solution, we don't have it if not requested from the executor list - if is_tuned: - # Restore - compile_data.executors_list = list(cached_executor_list) - - log( - f"================================================================================ Before Autotune Tuning: autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) - if compile_stats is not None: - compile_stats.last_traces.append(best_primal_trace) - compile_stats.last_traces += best_fw_traces - compile_stats.last_backward_traces += best_bw_traces - - return best_fw_extrace, best_bw_extrace - # If no solution is found at this optmization step, we proceed normally - else: - # Restore before calling split - compile_data.executors_list = list(cached_executor_list) - - log( - f"================================================================================ Before Autotune Tuning: not autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) - primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() - if compile_stats is not None: - compile_stats.last_traces.append(primal_trace) - compile_stats.last_traces += fw_traces - compile_stats.last_backward_traces += bw_traces - - return fw_extrace, bw_extrace - # TODO: catch individual failures - except Exception as e: + if not is_tuned: + raise AssertionError( + f"No executors have been placed inside the trace. Will autotune the computation_trc ignoring the following executors:\n{executors_candidates}" + ) + # Restore + compile_data.executors_list = list(cached_executor_list) + log( + f"================================================================================ Before Autotune Tuning: autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) + if compile_stats is not None: + compile_stats.last_traces.append(best_primal_trace) + compile_stats.last_traces += best_fw_traces + compile_stats.last_backward_traces += best_bw_traces + + return best_fw_extrace, best_bw_extrace + except Exception as exception: import traceback - log(f'Exception occured when tuning {executors_candidates}:\n{e}\nTraceback:') - traceback.print_exc() + ex_str = traceback.format_exc() + log(f'Exception occured when tuning {executors_candidates}: {exception}\n{ex_str}') # Restore before calling split - compile_data.executors_list = list(cached_executor_list) + compile_data.executors_list = cached_executor_list_copy log( f"================================================================================ Before Autotune Tuning: exception occured, executors:\n{executors_candidates} will not be autotuned (priority list policy will be used)", From 30ef6ca36e902f10a93ff32e37f6e9071a2d46c1 Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 19 Aug 2024 11:07:51 +0300 Subject: [PATCH 054/171] Disabling reverse search --- thunder/backend_optimizer/optimizer.py | 30 +++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index fa2883a468..0fb8610283 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -697,21 +697,21 @@ def measure_and_update_result(): # From bottom to up (this will exclude the full region as being handled in the for cycle above) # -> First iteration is the one with len(fusion_region) - 1 # -> Last iteration gives no fusion regions - for j in range(start_idx, start_idx + i + 1, increment_factor): - match_bsym_output( - group[j], - [dict_time_strat, dict_mem_strat], - get_first_available_operator_executor( - bsym=group[j], - executors=self.executors, - empty_hash=self.empty_executor_hashable_placeholder, - ), - ) - for k in range(start_idx + i + 1, len(group), increment_factor): - match_bsym_output(group[k], [dict_time_strat, dict_mem_strat], ex) - - # Benchmark this placement - measure_and_update_result() + # for j in range(start_idx, start_idx + i + 1, increment_factor): + # match_bsym_output( + # group[j], + # [dict_time_strat, dict_mem_strat], + # get_first_available_operator_executor( + # bsym=group[j], + # executors=self.executors, + # empty_hash=self.empty_executor_hashable_placeholder, + # ), + # ) + # for k in range(start_idx + i + 1, len(group), increment_factor): + # match_bsym_output(group[k], [dict_time_strat, dict_mem_strat], ex) + + # # Benchmark this placement + # measure_and_update_result() if best_placement_time is None or best_keys_time is None: raise AssertionError("Failed to get best time placement") From c43e66a8472c48cd44b0ddb3536fb0d302ccaf45 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 20 Aug 2024 12:13:05 +0300 Subject: [PATCH 055/171] Transformer Engine support (#15) --- examples/dev/LLaMAMLP.py | 11 +- examples/dev/MLP.py | 16 +- examples/dev/litGPT.py | 14 +- examples/dev/nanogpt-block.py | 18 +- examples/dev/nanogpt.py | 11 +- examples/dev/nvfuser_optimizations.py | 17 +- examples/dev/sdpa.py | 23 +- examples/dev/sdpa_linear.py | 53 ++ examples/dev/sdpa_slow.py | 82 --- examples/dev/sdpaex_qsize.txt:q | 707 ++++++++++++++++++++++ examples/dev/simple.py | 18 +- examples/dev/te.py | 53 ++ examples/dev/transformer.py | 15 - thunder/backend_optimizer/optimizer.py | 85 ++- thunder/backend_optimizer/utils.py | 204 ++++--- thunder/benchmarks/utils.py | 34 +- thunder/core/trace.py | 5 + thunder/executors/torch_autograd.py | 108 ++-- thunder/executors/transformer_engineex.py | 7 +- 19 files changed, 1158 insertions(+), 323 deletions(-) create mode 100644 examples/dev/sdpa_linear.py delete mode 100644 examples/dev/sdpa_slow.py create mode 100644 examples/dev/sdpaex_qsize.txt:q create mode 100644 examples/dev/te.py delete mode 100644 examples/dev/transformer.py diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index dd9b3ce262..9e89b0139c 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -30,14 +30,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: iters = 100 print('Results with thunder benchmark:') - traces = [ + fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1], ] - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, iters, nvsight = False) + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] @@ -48,5 +51,3 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch_total_benchmark(callables, labels, inputs, iters) torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) - for t in traces: - print(f'{t}\n#####################################') diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py index 9ae0fcfd13..0f3317a328 100644 --- a/examples/dev/MLP.py +++ b/examples/dev/MLP.py @@ -45,23 +45,21 @@ def forward(self, x): callables = [jmodel_auto, jmodel_def] labels = ['auto', 'def'] inputs = [x, x] - print('Results with torch fw bw benchmark:') - torch_fw_bw_benchmark(callables, labels, inputs, iters) + print('Results with torch total benchmark:') torch_total_benchmark(callables, labels, inputs, iters) - torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) print('Results with thunder benchmark:') - traces = [ + fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1], ] - traces.reverse() - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - labels.reverse() - thunder_fw_bw_benchmark(traces, labels, iters, nvsight = False) - thunder_fw_bw_benchmark(traces, labels, iters, nvsight = True) + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) # for t in traces: # print(t) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 8d3aad2b7d..85dc00d35a 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -32,9 +32,17 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: iters = 100 print(f'Results thunder benchmark ({iters} iters):') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, iters) + fw_traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) print(f'\n\nResults torch fw bw benchmark ({iters} iters):') callables = [jmodel_def, jmodel_auto] diff --git a/examples/dev/nanogpt-block.py b/examples/dev/nanogpt-block.py index 75e8c9c854..2992525296 100644 --- a/examples/dev/nanogpt-block.py +++ b/examples/dev/nanogpt-block.py @@ -117,17 +117,25 @@ def forward(self, x): print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) - + iters = 100 print('Results thunder benchmark:') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, 5) + fw_traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) print('\n\nResults torch fw bw benchmark:') callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] inputs = [x, x] - torch_fw_bw_benchmark(callables, labels, inputs, 5) + torch_fw_bw_benchmark(callables, labels, inputs, 100) print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index eb519adec5..da560160f8 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -147,16 +147,17 @@ def measure_nvsight(m, label): measure(jmodel_def, 'def') print('\n\nResults thunder benchmark:') - traces = [ + fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1], ] - traces.reverse() - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - labels.reverse() - thunder_fw_bw_benchmark(traces, labels, 100) + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 20) measure_nvsight(jmodel_def, 'def') measure_nvsight(jmodel_auto, 'auto') diff --git a/examples/dev/nvfuser_optimizations.py b/examples/dev/nvfuser_optimizations.py index 487b1b24aa..abd3fe09bc 100644 --- a/examples/dev/nvfuser_optimizations.py +++ b/examples/dev/nvfuser_optimizations.py @@ -27,22 +27,25 @@ def forward(self, x: torch.Tensor): x = torch.randn(1 << 9, in_features, requires_grad=True) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type="runtime", executors=["nvfuser", "cudnn", "torch", "python"]) y = jmodel_def(x) y = jmodel_auto(x) iters = 100 print('Results thunder benchmark:') - traces = [ + fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1], ] - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, iters) - thunder_fw_bw_benchmark(traces, labels, iters, nvsight=True) + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) + # thunder_fw_bw_benchmark(traces, labels, iters, nvsight=True) callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] @@ -52,5 +55,7 @@ def forward(self, x: torch.Tensor): torch_total_benchmark(callables, labels, inputs, iters) torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) - for t in traces: + for t in fw_traces: + print(f'{t}\n#########################################') + for t in bw_traces: print(f'{t}\n#########################################') diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index d307cae578..7dc0377111 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -1,6 +1,6 @@ import torch import thunder -from thunder.backend_optimizer.optimizer import benchmark_trace +from thunder.benchmarks.utils import thunder_fw_bw_benchmark class Model(torch.nn.Module): def __init__(self) -> None: @@ -24,15 +24,18 @@ def forward(self, query, key, value): print('deviation def:', (jmodel_def(q, k, v) - model(q, k, v)).abs().max().item()) print('deviation auto:', (jmodel_auto(q, k, v) - model(q, k, v)).abs().max().item()) - print('########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_fw', iters=10) - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_fw', iters=10) - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_bw', iters=10) - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_bw', iters=10) - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + iters = 100 + fw_traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') diff --git a/examples/dev/sdpa_linear.py b/examples/dev/sdpa_linear.py new file mode 100644 index 0000000000..7627160828 --- /dev/null +++ b/examples/dev/sdpa_linear.py @@ -0,0 +1,53 @@ +import torch +import thunder +from thunder.backend_optimizer.optimizer import benchmark_trace + +torch.set_default_dtype(torch.float32) + +class Model(torch.nn.Module): + def __init__(self, inf, outf) -> None: + super().__init__() + self.linear = torch.nn.Linear(inf, outf, bias=False) + + def forward(self, query, key, value): + query = self.linear(query) + a = torch.nn.functional.scaled_dot_product_attention(query, key, value) + return a + +with torch.device('cuda'): + features = 128 + model = Model(features, features) + + jmodel_def = thunder.jit(model) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'cudnn', 'sdpa', 'fa3', 'torchcompile']) + + q = torch.rand(32, 8, 128, features, requires_grad=True) + k = torch.rand(32, 8, 128, features, requires_grad=True) + v = torch.rand(32, 8, 128, features, requires_grad=True) + + print('deviation def:', (jmodel_def(q, k, v) - model(q, k, v)).abs().max().item()) + print('deviation auto:', (jmodel_auto(q, k, v) - model(q, k, v)).abs().max().item()) + + print('########################################') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_fw', iters=10) + print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') + c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_fw', iters=10) + print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_bw', iters=10) + print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') + c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_bw', iters=10) + print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') + + print('\n\n\n\n\n\n') + print(f'{thunder.last_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_traces(jmodel_auto)[-1]}') + + print('\n\n') + print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') + print('###############################################################################') + print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + + + + diff --git a/examples/dev/sdpa_slow.py b/examples/dev/sdpa_slow.py deleted file mode 100644 index c9d4381757..0000000000 --- a/examples/dev/sdpa_slow.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import thunder -import math - -class ModelConfig: - def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): - self.n_embd = n_embd - self.n_head = n_head - self.dropout = dropout - self.bias = bias - self.block_size = block_size - -class Module(nn.Module): - def __init__(self, config): - """ - My implementation of NanoGPT Causal Self Attention module for PyTorch. - - Args: - - config: Configuration object containing parameters for the attention module. - """ - super().__init__() - assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - # regularization - self.attn_dropout = nn.Dropout(config.dropout) - self.resid_dropout = nn.Dropout(config.dropout) - self.n_head = config.n_head - self.n_embd = config.n_embd - self.dropout = config.dropout - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) - - def forward(self, x): - """ - Forward pass of the Causal Self Attention module. - - Args: - - x: Input tensor. - - Returns: - - torch.Tensor: Output tensor after self-attention. - """ - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - - # manual implementation of attention - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - - # output projection - y = self.resid_dropout(self.c_proj(y)) - return y - -with torch.device('cuda'): - config = ModelConfig(n_embd = 1536) - module = Module(config) - j_module = thunder.jit(module) - - batch_size, sequence_length, embedding_dim = 8, 16, config.n_embd - x = torch.randn((batch_size, sequence_length, embedding_dim)) - - ans = j_module(x) - - print(thunder.last_traces(j_module)[-1]) - - - diff --git a/examples/dev/sdpaex_qsize.txt:q b/examples/dev/sdpaex_qsize.txt:q new file mode 100644 index 0000000000..a27d6e5e39 --- /dev/null +++ b/examples/dev/sdpaex_qsize.txt:q @@ -0,0 +1,707 @@ +filtered: {'linear': [], 'scaled_dot_product_attention': [thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('cudnn')]} +->unpack_trivial +->unpack_trivial +->unpack_trivial +->scaled_dot_product_attention +3 -> cudnn +->scaled_dot_product_attention +4 -> fa3 +->add +->python_return +Assigned: {'scaled_dot_product_attention': {3: thunder.extend.OperatorExecutor('cudnn'), 4: thunder.extend.OperatorExecutor('fa3')}} +Input ex: (thunder.extend.OperatorExecutor('nvfuser'), thunder.extend.OperatorExecutor('torchcompile'), thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('cudnn'), thunder.extend.OperatorExecutor('torch'), thunder.extend.OperatorExecutor('python')) +================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Optimizing linear +================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Skipping optimization for linear as not requested or not present in computation_trc. +================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Optimizing scaled_dot_product_attention +================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Executors to bench for scaled_dot_product_attention: [thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('cudnn'), thunder.extend.OperatorExecutor('torch')] +================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Testing compile data executors: [thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('torch'), thunder.extend.OperatorExecutor('nvfuser'), thunder.extend.OperatorExecutor('torchcompile'), thunder.extend.OperatorExecutor('python')] +#FW after trace call +._trace at 0x7cdda07cfe20> +#FW after trace call +._trace at 0x7cdd9ee381f0> +#FW after trace call +._trace at 0x7cdda07cfe20> +#FW after trace call +._trace at 0x7cdd9ee38b80> +#FW after trace call +._trace at 0x7cdda07cfd90> +================================================================================ Autotune: Executors: +================================================================================ Autotune: sdpa -> is operator = True, is fusion = False +================================================================================ Autotune: torch -> is operator = True, is fusion = False +================================================================================ Autotune: nvfuser -> is operator = False, is fusion = True +================================================================================ Autotune: torchcompile -> is operator = False, is fusion = True +================================================================================ Autotune: python -> is operator = True, is fusion = False +================================================================================ Autotune: New forward trace to optimize (strat = OptimizerType.RUNTIME): +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.core.dtypes as dtypes +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" + t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" + t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) +================================================================================ Autotune: Searching best placement for fusion executor = nvfuser +================================================================================ Autotune: Searching best placement for fusion executor = torchcompile +================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.014675200078636408 ms, mem = 0.002941131591796875 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = nvFusion0(t0) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) +================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.014627200085669756 ms, mem = 0.002941131591796875 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = nvFusion0(t0) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) +================================================================================ Autotune: Caching fw candidate [compile option: None] +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = nvFusion0(t0) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) +================================================================================ Autotune: Caching fw candidate [compile option: None] +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = nvFusion0(t0) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) +================================================================================ Autotune: End fw time mem pair +================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.026291199773550034 ms, mem = 0.00331878662109375 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = TorchCompile0(t0, t7) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" + # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) +================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.02619839971885085 ms, mem = 0.00331878662109375 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = TorchCompile0(t0, t7) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" + # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) +================================================================================ Autotune: Caching fw candidate [compile option: None] +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = TorchCompile0(t0, t7) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" + # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) +================================================================================ Autotune: Caching fw candidate [compile option: None] +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = TorchCompile0(t0, t7) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" + # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) +================================================================================ Autotune: End fw time mem pair +================================================================================ Autotune: New backward trace to optimize (strat = OptimizerType.RUNTIME): +# Constructed by Backward pass +import thunder +import thunder.core.dtypes as dtypes +import thunder.core.prims as prims +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, C1, = saved_for_backward + t18, = cotangents + query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12, = C0 + # C1 (empty sequence) + t14 = prims.convert_element_type(t18, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + t15 = prims.convert_element_type(t14, dtypes.bfloat16) # t15: "cuda:0 bf16[4, 128, 6, 64]" + t16 = prims.convert_element_type(t14, dtypes.bfloat16) # t16: "cuda:0 bf16[4, 128, 6, 64]" + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t15, query, key, value, t7, t8, t9, t10, 6, 6, 0.0, False, t11, t12, scale=None) + (t21, t22, t23) = sdpafx_scaled_dot_product_efficient_attention_backward(t16, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + t24 = prims.convert_element_type(t17, dtypes.float32) # t24: "cuda:0 f32[4, 128, 6, 64]" + t25 = prims.convert_element_type(t21, dtypes.float32) # t25: "cuda:0 f32[4, 128, 6, 64]" + t26 = prims.add(t24, t25) # t26: "cuda:0 f32[4, 128, 6, 64]" + t27 = prims.convert_element_type(t26, dtypes.bfloat16) # t27: "cuda:0 bf16[4, 128, 6, 64]" + t28 = prims.convert_element_type(t19, dtypes.float32) # t28: "cuda:0 f32[4, 128, 6, 64]" + t29 = prims.convert_element_type(t22, dtypes.float32) # t29: "cuda:0 f32[4, 128, 6, 64]" + t30 = prims.add(t28, t29) # t30: "cuda:0 f32[4, 128, 6, 64]" + t31 = prims.convert_element_type(t30, dtypes.bfloat16) # t31: "cuda:0 bf16[4, 128, 6, 64]" + t32 = prims.convert_element_type(t20, dtypes.float32) # t32: "cuda:0 f32[4, 128, 6, 64]" + t33 = prims.convert_element_type(t23, dtypes.float32) # t33: "cuda:0 f32[4, 128, 6, 64]" + t34 = prims.add(t32, t33) # t34: "cuda:0 f32[4, 128, 6, 64]" + t35 = prims.convert_element_type(t34, dtypes.bfloat16) # t35: "cuda:0 bf16[4, 128, 6, 64]" + return (t27, t31, t35) +================================================================================ Autotune: Backward optimization with fw from nvfuser +Current fw cached ctx: +# Constructed by Autotuned transform for execution (took 3006 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) + [t17] = nvFusion0(t0) + # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" + # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" + # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) +Options: None +================================================================================ Autotune: Searching best placement for fusion executor = nvfuser +#FW after trace call +._trace at 0x7cdd4a1d77f0> +#FW after trace call +._trace at 0x7cdd4a1d4160> +#FW after trace call +._trace at 0x7cdd4a0201f0> +#FW after trace call +._trace at 0x7cdd4a1d4f70> +#FW after trace call +._trace at 0x7cdd4a1d4160> +TORCH DTYPE t2 torch.int64 +TORCH DTYPE t3 torch.int64 +TORCH DTYPE t4 torch.int64 +TORCH DTYPE t5 torch.int64 +#FN EXECUTION FAILED: +# Constructed by Delete Last Used (took 0 milliseconds) +from torch import Tensor +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + del saved_for_backward + t18, = cotangents + del cotangents + key, query, t0, t1, t2, t3, t4, t5, value, = C0 + del C0 + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + del t18, query, key, value, t0, t1, t2, t3, t4, t5 + [t24] = nvFusion1(t17) + t28 = Tensor.to(t19, torch.float32, copy=True) # t28: "cuda:0 f32[4, 128, 6, 64]" + del t19 + t32 = Tensor.to(t20, torch.float32, copy=True) # t32: "cuda:0 f32[4, 128, 6, 64]" + del t20 + t25 = Tensor.to(t17, torch.float32, copy=True) # t25: "cuda:0 f32[4, 128, 6, 64]" + del t17 + t26 = torch.add(t24, t25) # t26: "cuda:0 f32[4, 128, 6, 64]" + del t24, t25 + t30 = torch.add(t28, t28) # t30: "cuda:0 f32[4, 128, 6, 64]" + del t28 + t34 = torch.add(t32, t32) # t34: "cuda:0 f32[4, 128, 6, 64]" + del t32 + t27 = Tensor.to(t26, torch.bfloat16, copy=True) # t27: "cuda:0 bf16[4, 128, 6, 64]" + del t26 + t31 = Tensor.to(t30, torch.bfloat16, copy=True) # t31: "cuda:0 bf16[4, 128, 6, 64]" + del t30 + t35 = Tensor.to(t34, torch.bfloat16, copy=True) # t35: "cuda:0 bf16[4, 128, 6, 64]" + del t34 + return t27, t31, t35 +Traceback (most recent call last): + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms + raise e + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms + fn(*args) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "thunder.backward_fn_48", line 17, in backward_fn + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl + grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) +RuntimeError: cu_seqlens_q must have dtype int32 + +#FW after trace call +._trace at 0x7cdd9ee3bc70> +TORCH DTYPE t2 torch.int64 +TORCH DTYPE t3 torch.int64 +TORCH DTYPE t4 torch.int64 +TORCH DTYPE t5 torch.int64 +#FN EXECUTION FAILED: +# Constructed by Delete Last Used (took 0 milliseconds) +from torch import Tensor +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + del saved_for_backward + t18, = cotangents + del cotangents + key, query, t0, t1, t2, t3, t4, t5, value, = C0 + del C0 + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + del t18, query, key, value, t0, t1, t2, t3, t4, t5 + t24 = Tensor.to(t17, torch.float32, copy=True) # t24: "cuda:0 f32[4, 128, 6, 64]" + [t27, t31, t35] = nvFusion1(t17, t24, t19, t20) + del t17, t24, t19, t20 + return t27, t31, t35 +Traceback (most recent call last): + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms + raise e + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms + fn(*args) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "thunder.backward_fn_49", line 17, in backward_fn + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl + grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) +RuntimeError: cu_seqlens_q must have dtype int32 + +#FW after trace call +._trace at 0x7cdd4a0875b0> +TORCH DTYPE t2 torch.int64 +TORCH DTYPE t3 torch.int64 +TORCH DTYPE t4 torch.int64 +TORCH DTYPE t5 torch.int64 +#FN EXECUTION FAILED: +# Constructed by Delete Last Used (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + del saved_for_backward + t18, = cotangents + del cotangents + key, query, t0, t1, t2, t3, t4, t5, value, = C0 + del C0 + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + del t18, query, key, value, t0, t1, t2, t3, t4, t5 + [t27, t31, t35] = nvFusion1(t17, t19, t20) + del t17, t19, t20 + return t27, t31, t35 +Traceback (most recent call last): + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms + raise e + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms + fn(*args) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "thunder.backward_fn_50", line 16, in backward_fn + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl + grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) +RuntimeError: cu_seqlens_q must have dtype int32 + +#FW after trace call +._trace at 0x7cdd9ee3aef0> +TORCH DTYPE t2 torch.int64 +TORCH DTYPE t3 torch.int64 +TORCH DTYPE t4 torch.int64 +TORCH DTYPE t5 torch.int64 +#FN EXECUTION FAILED: +# Constructed by Delete Last Used (took 0 milliseconds) +from torch import Tensor +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def backward_fn(saved_for_backward, cotangents): + # saved_for_backward: "Collection" + # cotangents: "Collection" + C0, _, = saved_for_backward + del saved_for_backward + t18, = cotangents + del cotangents + key, query, t0, t1, t2, t3, t4, t5, value, = C0 + del C0 + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + del t18, query, key, value, t0, t1, t2, t3, t4, t5 + t24 = Tensor.to(t17, torch.float32, copy=True) # t24: "cuda:0 f32[4, 128, 6, 64]" + del t17 + t28 = Tensor.to(t19, torch.float32, copy=True) # t28: "cuda:0 f32[4, 128, 6, 64]" + del t19 + t32 = Tensor.to(t20, torch.float32, copy=True) # t32: "cuda:0 f32[4, 128, 6, 64]" + del t20 + t26 = torch.add(t24, t24) # t26: "cuda:0 f32[4, 128, 6, 64]" + del t24 + t30 = torch.add(t28, t28) # t30: "cuda:0 f32[4, 128, 6, 64]" + del t28 + t34 = torch.add(t32, t32) # t34: "cuda:0 f32[4, 128, 6, 64]" + del t32 + t27 = Tensor.to(t26, torch.bfloat16, copy=True) # t27: "cuda:0 bf16[4, 128, 6, 64]" + del t26 + t31 = Tensor.to(t30, torch.bfloat16, copy=True) # t31: "cuda:0 bf16[4, 128, 6, 64]" + del t30 + t35 = Tensor.to(t34, torch.bfloat16, copy=True) # t35: "cuda:0 bf16[4, 128, 6, 64]" + del t34 + return t27, t31, t35 +Traceback (most recent call last): + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms + raise e + File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms + fn(*args) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast + return func(*args, **kwargs) + File "thunder.backward_fn_51", line 17, in backward_fn + (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) + File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl + grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) + File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ + return func(*args, **kwargs) + File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ + return self._op(*args, **(kwargs or {})) +RuntimeError: cu_seqlens_q must have dtype int32 + +================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Failed to place sdpa: Failed to get best time placement +Traceback (most recent call last): + File "/workspace/workdir/thunder/executors/torch_autograd.py", line 458, in split_forward_backward + primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() + File "/workspace/workdir/thunder/executors/torch_autograd.py", line 248, in split + fw_extrace, bw_extrace = autotune_transform_for_execution( + File "/workspace/workdir/thunder/executors/passes.py", line 164, in autotune_transform_for_execution + optimizer_context.optimize() + File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 1124, in optimize + self.optimizer.optimize() + File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 974, in optimize + _optimize() + File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 887, in _optimize + self._search_candidates() + File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 827, in _search_candidates + _search(ex) + File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 717, in _search + raise AssertionError("Failed to get best time placement") +AssertionError: Failed to get best time placement + +================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Testing compile data executors: [thunder.extend.OperatorExecutor('cudnn'), thunder.extend.OperatorExecutor('torch'), thunder.extend.OperatorExecutor('nvfuser'), thunder.extend.OperatorExecutor('torchcompile'), thunder.extend.OperatorExecutor('python')] +#FW after trace call +._trace at 0x7cdd4a022950> +#FW after trace call +._trace at 0x7cdd5bef48b0> +#FW after trace call +._trace at 0x7cdd5bef48b0> +#FW after trace call +._trace at 0x7cdd4a022830> +#FW after trace call +._trace at 0x7cdd9ee3bb50> +================================================================================ Autotune: Executors: +================================================================================ Autotune: cudnn -> is operator = True, is fusion = False +================================================================================ Autotune: torch -> is operator = True, is fusion = False +================================================================================ Autotune: nvfuser -> is operator = False, is fusion = True +================================================================================ Autotune: torchcompile -> is operator = False, is fusion = True +================================================================================ Autotune: python -> is operator = True, is fusion = False +================================================================================ Autotune: New forward trace to optimize (strat = OptimizerType.RUNTIME): +# Constructed by Dead Code Elimination (took 0 milliseconds) +import thunder +import thunder.core.dtypes as dtypes +import thunder.core.prims as prims +import thunder.torch as ltorch +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" + t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" + t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) +================================================================================ Autotune: Searching best placement for fusion executor = nvfuser +================================================================================ Autotune: Searching best placement for fusion executor = torchcompile +================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.010593599872663617 ms, mem = 0.002941131591796875 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + [t11] = nvFusion0(t0) + # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) +================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.010647999914363026 ms, mem = 0.002941131591796875 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + [t11] = nvFusion0(t0) + # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) +================================================================================ Autotune: Caching fw candidate [compile option: None] +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + [t11] = nvFusion0(t0) + # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) +================================================================================ Autotune: Caching fw candidate [compile option: None] +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + [t11] = nvFusion0(t0) + # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) +================================================================================ Autotune: End fw time mem pair +================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.018396800477057697 ms, mem = 0.00331878662109375 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + [t11] = TorchCompile0(t0, t4) + # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + # t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" + # t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) +================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.01832640040665865 ms, mem = 0.00331878662109375 GB)": +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + [t11] = TorchCompile0(t0, t4) + # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + # t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" + # t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) +================================================================================ Autotune: Caching fw candidate [compile option: None] +# Constructed by Transform for operator executor execution (took 0 milliseconds) +import torch +from thunder.executors.torchex import no_autocast + +@torch.no_grad() +@no_autocast +def augmented_forward_fn(query, key, value): + # query: "cuda:0 bf16[4, 128, 6, 64]" + # key: "cuda:0 bf16[4, 128, 6, 64]" + # value: "cuda:0 bf16[4, 128, 6, 64]" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) + [t11] = TorchCompile0(t0, t4) + # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" + # t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" + # t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" + # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" + return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) diff --git a/examples/dev/simple.py b/examples/dev/simple.py index 0add49560a..be19d6fd62 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -26,16 +26,22 @@ def forward(self, x: torch.Tensor): y = jmodel_def(x) y = jmodel_auto(x) + iters = 100 print('Results thunder benchmark:') - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] - labels = ['fw_def', 'fw_auto', 'bw_def', 'bw_auto'] - thunder_fw_bw_benchmark(traces, labels, 50) + fw_traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) callables = [jmodel_def, jmodel_auto] labels = ['def', 'auto'] inputs = [x, x] print('Results torch benchmark:') torch_fw_bw_benchmark(callables, labels, inputs, 50) - - for t in traces: - print(f'{t}\n###################') diff --git a/examples/dev/te.py b/examples/dev/te.py new file mode 100644 index 0000000000..efd59c7392 --- /dev/null +++ b/examples/dev/te.py @@ -0,0 +1,53 @@ +import torch +import thunder +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark + +class Module(torch.nn.Module): + def __init__(self, in_features, out_features) -> None: + super().__init__() + self.linear = torch.nn.Sequential( + torch.nn.Linear(in_features, out_features), + ) + + def forward(self, x: torch.Tensor): + a = x + x + return self.linear(a) + +with torch.device('cuda'): + m = 1 + in_features = 4096 * m + out_features = 4096 * m + model = Module(in_features, out_features) + x = torch.randn(768, in_features, requires_grad=True) + + jmodel_def = thunder.jit(model, executors=['transformer_engine'], use_cudagraphs=False) + jmodel_auto = thunder.jit( + model, + autotune_type="runtime", + executors=["nvfuser", "transformer_engine", "cudnn", "torch"], + use_cudagraphs=False, + ) + + y = jmodel_def(x) + y = jmodel_auto(x) + + iters = 100 + fw_traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + ] + bw_traces = [ + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] + print('Results thunder benchmark:') + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) + + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] + print('\n\nResults torch benchmark:') + torch_total_benchmark(callables, labels, inputs, iters) + diff --git a/examples/dev/transformer.py b/examples/dev/transformer.py deleted file mode 100644 index 2a4c5d4be6..0000000000 --- a/examples/dev/transformer.py +++ /dev/null @@ -1,15 +0,0 @@ - -import torch -import thunder - -with torch.device('cuda'): - transformer_model = torch.nn.Transformer(nhead=16, num_encoder_layers=12) - src = torch.rand((10, 32, 512)) - tgt = torch.rand((20, 32, 512)) - out = transformer_model(src, tgt) - print(out) - - - jmodel = thunder.jit(transformer_model) - out = jmodel(src, tgt) - diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 0fb8610283..92f644f784 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -10,8 +10,6 @@ from typing import Hashable from thunder.backend_optimizer.utils import benchmark_trace - -# Currently this manages both time and memory class BenchmarkResult: def __init__( self, @@ -28,7 +26,6 @@ def __init__( self.label: str | Hashable = label self.index: int = index - class OptimizerType(Enum): MEMORY = 0 RUNTIME = 1 @@ -101,7 +98,7 @@ def __init__( self.tot_cost: float = cost def __repr__(self) -> str: - return f"Final output candidate: forward trace:\n{self.fw.__repr__()}\nFinal output candidate: backward trace:{self.bw.__repr__()}" + return f"Final output candidate: forward trace:\n{self.fw.__repr__()}\nFinal output candidate: backward trace:\n{self.bw.__repr__()}" # Benchmark only traces will contain traces after the rematerialization call with fw and bw calls, reproducing what will be the real traces after the autotune pass @@ -273,22 +270,22 @@ def _best_runtime_and_memory_candidates(self, candidates): pair_cost_time = 0 pair_cost_mem = 0 t, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) - log(f"Pair fw time: {t}, mem: {m}", level=LogLevel.INFO) + # log(f"Pair fw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.INFO) pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m - t, m, _ = benchmark_trace(bw, iters=self.benchmark_iters) - log(f"Pair bw time: {t}, mem: {m}", level=LogLevel.INFO) + t, m, _ = benchmark_trace(bw, iters=self.benchmark_iters, fw_trace=fw) + # log(f"Pair bw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.INFO) pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m if pair_cost_time < min_value_time: best_pair_runtime = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_time) - log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.INFO) + # log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.INFO) min_value_time = pair_cost_time if pair_cost_mem < min_value_mem: best_pair_memory = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_mem) - log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.INFO) + # log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.INFO) min_value_mem = pair_cost_mem return best_pair_runtime, best_pair_memory @@ -310,18 +307,18 @@ def fw_benchmark(): label = list(pair_time.keys())[0] # TODO (matteochen): remove the benchmark here as will done later on the bw pass c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) - log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', - level=LogLevel.INFO, - ) + # log( + # f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', + # level=LogLevel.INFO, + # ) self.debug_msg += ( f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" ) c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) - log( - f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', - level=LogLevel.INFO, - ) + # log( + # f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', + # level=LogLevel.INFO, + # ) self.debug_msg += ( f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) @@ -332,13 +329,11 @@ def fw_benchmark(): # print(f"Caching fw with compile options mem: {compile_opt_mem.fusion_tag}") for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): - log(f'Caching fw candidate [compile option: {o.fusion_tag if o else "None"}]\n{t}') + log(f'Caching fw candidate [compile option: {o.fusion_tag if o else "None"}]') self.cached_fw_traces.append( TraceCandidate(trace=t, compile_opt=o, label=label + '_enabled_' + o.fusion_tag if o is not None else label) ) - log("End fw time mem pair", level=LogLevel.INFO) - def bw_benchmark(): time_result = BenchmarkResult() memory_result = BenchmarkResult() @@ -348,12 +343,12 @@ def bw_benchmark(): # Unpack the dict label = list(pair.keys())[0] trace = list(pair.values())[0] - trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) + trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" - log( - f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', - level=LogLevel.INFO, - ) + # log( + # f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', + # level=LogLevel.INFO, + # ) if trace_time < time_result.runtime: time_result = BenchmarkResult(time=trace_time, memory=trace_mem, trace=trace, label=label, index=i) @@ -363,26 +358,26 @@ def bw_benchmark(): label = list(pair.keys())[0] trace = list(pair.values())[0] - trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters) + trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) del res self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" - log( - f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', - level=LogLevel.INFO, - ) + # log( + # f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', + # level=LogLevel.INFO, + # ) if trace_mem < memory_result.memory: memory_result = BenchmarkResult( time=trace_time, memory=trace_mem, trace=trace, label=label, index=i ) - log( - f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.runtime} ms)":\n{time_result.trace}', - level=LogLevel.INFO, - ) - log( - f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.memory / (2 ** 30)} GB)":\n{memory_result.trace}', - level=LogLevel.INFO, - ) + # log( + # f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.runtime} ms)":\n{time_result.trace}', + # level=LogLevel.INFO, + # ) + # log( + # f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.memory / (2 ** 30)} GB)":\n{memory_result.trace}', + # level=LogLevel.INFO, + # ) # Here we have to recover the traces without the pass through remat in order to be compliant # with thunder flow as we might have request for no remat @@ -594,7 +589,7 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) # Now, benchmark - t, m, _ = benchmark_trace(trc, self.benchmark_iters) + t, m, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) # Update results if t < candidate_best_time.runtime: candidate_best_time = BenchmarkResult(time=t, index=i) @@ -640,7 +635,7 @@ def measure_and_update_result(): trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) - cost, mem, _ = benchmark_trace(trc, self.benchmark_iters) + cost, mem, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) log(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) if cost < best_res_time.runtime or (cost == best_res_time.runtime and mem < best_res_time.memory): best_res_time = BenchmarkResult(time=cost, memory=mem, trace=trc) @@ -829,6 +824,7 @@ def measure_and_update_result(): # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. # TODO: Consider implementing patterns based on the executor under investingation if ex_compile_opts: + log(f'{ex.name} compile options: {[option.fusion_tag for option in ex_compile_opts]}') for opt in ex_compile_opts: # Search only if we have an instruction related to the compile option op_in_trace: bool = operation_in_trace(trace=self.trace, op=opt.symbol_tag) @@ -862,7 +858,7 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): match self.trace_type: case TraceType.FW: log( - f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO + f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.DEBUG ) # TODO (matteochen): support bw trace optimization even though with no fw traces cached (computational trace?) case TraceType.BW: @@ -870,7 +866,7 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): raise AssertionError("Can not optimize backward traces before forward traces") log( f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", - level=LogLevel.INFO, + level=LogLevel.DEBUG, ) def optimize(self): @@ -950,8 +946,9 @@ def _optimize(): self.trace = from_trace(cached_self_trace) self.trace.bound_symbols = list(cached_self_trace.bound_symbols) # Set the current active cached forward trace context - print( - f'Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.compile_opt.fusion_tag if fw_trace_candidate.compile_opt is not None else "None"}' + log( + f'Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.compile_opt.fusion_tag if fw_trace_candidate.compile_opt is not None else "None"}', + level=LogLevel.DEBUG ) self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.compile_opt diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index a3713a67ff..eea242df7b 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1,8 +1,8 @@ from collections.abc import Callable, Hashable, Sequence from typing import Any -from thunder.core.dtypes import dtype, to_torch_dtype +from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs -from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify +from thunder.core.proxies import AnyProxy, CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import TraceCtx, get_tracectx, reset_tracectx, set_tracectx from thunder.extend import Executor, FusionExecutor, OperatorExecutor @@ -10,8 +10,6 @@ import thunder.core.transforms as transforms from itertools import chain import torch -import thunder - def sequence_hash(s: Sequence) -> str: def rec(s) -> str: @@ -257,6 +255,13 @@ def operation_in_trace(*, trace: TraceCtx, op: str) -> bool: return True return False +def is_te_used(trace: TraceCtx) -> bool: + from thunder.executors.transformer_engineex import linear_bound_symbol_name_prefix + from thunder.executors.transformer_engineex import te_functional_linear_backward_name + for bsym in trace.bound_symbols: + if bsym.sym.name.startswith(linear_bound_symbol_name_prefix) or bsym.sym.name.startswith(te_functional_linear_backward_name): + return True + return False def benchmark_trace( trace: TraceCtx, @@ -267,14 +272,13 @@ def benchmark_trace( snapshot_name="", nvsight: bool = False, nvsight_fn_name: str = "", + **kwargs ) -> tuple[float, float, Any]: from thunder.executors.passes import del_last_used import inspect - input_args = [] - - if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: - raise AssertionError("Missing return statement") + # In order to benchmark traces with TE enabled, the backward pass needs the context object returned from the forward trace + cached_fw_te_ctx_out = None def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: @@ -303,22 +307,40 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f print(f"#NVSIGHT FN EXECUTION FAILED:\n{trc}") raise e + def clone_args(args): + res = [] + for arg in args: + if isinstance(arg, Sequence): + res.append(clone_args(arg)) + else: + if isinstance(arg, torch.Tensor): + res.append(arg.clone()) + else: + res.append(arg) + return tuple(res) + def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: try: warm_up_iters = 50 out = None + # print_args(args) + # Warm up cycles for _ in range(warm_up_iters): - fn(*args) + cloned_args = clone_args(args) + out = fn(*cloned_args) + del cloned_args # Snapshot request if snapshot: + cloned_args = clone_args(args) torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.memory._record_memory_history() - fn(*args) + fn(*cloned_args) torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") torch.cuda.memory._record_memory_history(enabled=None) + del cloned_args # Benchmark stream = torch.cuda.current_stream() max_allocated_bytes = 0 @@ -326,15 +348,17 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] torch.cuda.synchronize() for i in range(iters): + cloned_args = clone_args(args) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) start_events[i].record(stream) - fn(*args) + fn(*cloned_args) end_events[i].record(stream) max_allocated_bytes = max( max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) ) + del cloned_args torch.cuda.synchronize() times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] @@ -345,50 +369,6 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl print(f"#FN EXECUTION FAILED:\n{repr}") raise e - def print_input_args(args, level=0, show_content=False): - for e in args: - if isinstance(e, tuple) or isinstance(e, list): - print_input_args(e, level=level + 1) - else: - print(f"level {level}", type(e)) - - # def print_trace_execution_output(out: Any, show_content=False): - # if isinstance(out, tuple): - # for e in out: - # print(f'{type(e)}') - # else: - # print(f'{type(out)}') - - # TODO (matteochen): convert this into dict - def thunder_to_torch_float_dtype(tp: dtype, byte: int) -> torch.dtype: - if byte == 1: - raise AssertionError("Not implmented: 8 bit float") - # Dispatch flaot 16 type 1 from type 2 - elif byte == 2: - if tp._name == thunder.bfloat16._name: - return torch.bfloat16 - else: - return torch.float16 - elif byte == 4: - return torch.float32 - elif byte == 8: - return torch.float64 - else: - raise AssertionError(f"Not supported byte = {byte}") - - # TODO (matteochen): convert this into dict - def thunder_to_torch_int_dtype(byte: int) -> torch.dtype: - if byte == 1: - return torch.int8 - elif byte == 2: - return torch.int16 - elif byte == 4: - return torch.int32 - elif byte == 8: - return torch.int64 - else: - raise AssertionError(f"Not supported byte = {byte}") - # TODO (matteochen): use more appropriate mock int and float def transform_input_tuple(t: tuple, level=0) -> tuple: res = [] @@ -405,10 +385,13 @@ def transform_input_tuple(t: tuple, level=0) -> tuple: res.append(0 if e.value is None else e.value) elif isinstance(e, FloatProxy): res.append(0.0 if e.value is None else e.value) + # Transformer engine context object + elif hasattr(e, 'name') and isinstance(e, AnyProxy) and e.name.startswith('ctx_te'): + res.append(cached_fw_te_ctx_out) elif e is None: res.append(None) else: - raise AssertionError(f"Input arg type not recognized: {type(e)}") + raise AssertionError(f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}') return tuple(res) def transform_tensor(arg: TensorProxy) -> torch.Tensor: @@ -420,14 +403,23 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: shape = arg.shape device = arg.device requires_grad = arg.requires_grad - torch_dtype = to_torch_dtype(dtype) if torch_dtype is None: raise AssertionError(f'Unrecognized thunder dtype: {dtype}') if is_float_dtype(dtype): - tensor: torch.Tensor = torch.randn( - shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad - ) + # Use TE Float8 if TE is enabled, it has float32 ad torch dtype + if te_used: + tensor: torch.Tensor = torch.randn( + shape, dtype=torch_dtype if dtype.bytes > 1 else torch.float32, device=device.device_str(), requires_grad=requires_grad + ) + if dtype.bytes == 1: + import transformer_engine.pytorch as te + tensor = te.float8_tensor.Float8Tensor.to_float8(tensor) + # Support standard float tensors + else: + tensor: torch.Tensor = torch.randn( + shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) elif is_signedinteger_dtype(dtype): tensor: torch.Tensor = torch.randint( 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad @@ -442,23 +434,52 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: return tensor - # print(f'BENCHMARKING:\n{trace}') - # def p(args): - # for e in args: - # if not isinstance(e, Sequence): - # if isinstance(e, torch.Tensor): - # print(f'{e.size()}') - # else: - # try: - # print(f'{e.name} -> {e}') - # except: - # print(f'{e}') - # else: - # print('rec') - # p(e) - # p(trace.args) - # print('##################') - # p(input_args) + # We have to fix the saved_for_backward tuple as TE output TensorProxy don't have a correct one + def fix_te_backward_inputs(inputs: list): + saved_for_bw = [] + for i, e in enumerate(inputs[0][0]): + # This tensor should be an uint8 https://github.com/NVIDIA/TransformerEngine/blob/4edcff5777be08b6f89658572c433aa8f36acf0d/transformer_engine/pytorch/module/linear.py#L366 + if i == 1: + inputmat_t = e + if inputmat_t.dtype != torch.uint8: + inputmat_t = torch.randint(0, 8, (e.shape), dtype=torch.uint8, device=e.device) + saved_for_bw.append(inputmat_t) + else: + saved_for_bw.append(e) + + fixed_inputs_first_index = tuple([tuple(saved_for_bw), inputs[0][1]]) + return fixed_inputs_first_index + + # Trace real input args + input_args = [] + + # Check for correctness + if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: + raise AssertionError("Missing return statement") + + if apply_del_last_used: + trace = del_last_used(trace) + + # Enable TE fp8 autocast if needed + te_used = is_te_used(trace) + if te_used: + cached_te_fp8_autocast_value = trace._include_te_fp8_autocast + trace._include_te_fp8_autocast = True + + # If transformer_engine executor is used and it is the bw function we have to recover the forward context from the forward trace + trace_signature = trace.signature_with_no_ctx() + if te_used and trace_signature.startswith('def backward') and 'fw_trace' not in kwargs: + raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with TE executor') + elif te_used and trace_signature.startswith('def backward'): + # print('TE Benchmarking fw trace for bw') + fw_trace = kwargs.get('fw_trace', None) + if not isinstance(fw_trace, TraceCtx): + raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') + # Run the fw trace and get the outputs + fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] + # Retrive the context from the fw pass output + # Currently it will contain an empty transformer_engineex.Context but might be useful for the future + cached_fw_te_ctx_out = fw_output[1][1][0] # Can we remove this check? # TODO (matteochen): use more appropriate mock int and float @@ -481,8 +502,11 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: else: raise AssertionError("Unexpexcted args type") - if apply_del_last_used: - trace = del_last_used(trace) + if te_used and trace_signature.startswith('def backward'): + first_tuple = fix_te_backward_inputs(input_args) + input_args.pop(0) + input_args.insert(0, first_tuple) + trace_tok = set_tracectx(trace) @@ -521,6 +545,10 @@ def transform_tensor(arg: TensorProxy) -> torch.Tensor: finally: reset_tracectx(trace_tok) + # Restore the autocast value to not mess up the input trace + if te_used: + trace._include_te_fp8_autocast = cached_te_fp8_autocast_value + return t, m, answer @@ -569,3 +597,21 @@ def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *ar cd.compile_options[option.fusion_tag] = old_opt return out + +def print_trace_args(trace: TraceCtx): + print_args(trace.args) + +# Display nest sequence arguments +def print_args(args): + print('\n\n###################################### Debug args') + def _print(args): + print('Sequence start') + for arg in args: + if isinstance(arg, Sequence): + _print(arg) + else: + tensor_shape = arg.shape if isinstance(arg, torch.Tensor) else None + print(f'{type(arg)} {tensor_shape if tensor_shape else ""}') + print('Sequence end') + _print(args) + print('###################################### Debug args\n') diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 8b0109ecd9..85d8a098a0 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -1,8 +1,32 @@ +from collections.abc import Sequence import torch from thunder.backend_optimizer.utils import benchmark_trace +from thunder.core.trace import TraceCtx warm_up_iters = 50 +class AutotunerTorchAutogradBenchmarkUtils(): + def __init__( + self, + cost: float = float('inf'), + fw_trace: TraceCtx | None = None, + bw_trace: TraceCtx | None = None, + fw_traces: Sequence[TraceCtx] = [], + bw_traces: Sequence[TraceCtx] = [], + primal_trace: TraceCtx | None = None, + executor = None, + selected_executors: Sequence = [] + ) -> None: + self.cost: float = cost + self.fw_trace = fw_trace + self.bw_trace = bw_trace + self.fw_traces = fw_traces + self.bw_traces = bw_traces + self.primal_trace = primal_trace + self.executor = executor + self.selected_executors = selected_executors + + def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int) -> None: for m, input, label in zip(models, inputs, labels): @@ -123,8 +147,14 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) print(f'{label} max allocated memory: {max_allocated_bytes / (2**30)} GB') -def thunder_fw_bw_benchmark(traces: list, labels: list, iters: int, nvsight: bool = False) -> None: - for trc, label in zip(traces, labels): +def thunder_fw_bw_benchmark(fw_traces: list, bw_traces: list, fw_labels: list, bw_labels: list, iters: int, nvsight: bool = False) -> None: + assert(len(fw_traces) == len(bw_traces) == len(fw_labels) == len(bw_labels)) + for trc, label in zip(fw_traces, fw_labels): c, m, _ = benchmark_trace(trc, apply_del_last_used=False, snapshot=True, snapshot_name=label, iters=iters, nvsight=nvsight, nvsight_fn_name=label) print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') + i = 0 + for trc, label in zip(bw_traces, bw_labels): + c, m, _ = benchmark_trace(trc, apply_del_last_used=False, snapshot=True, snapshot_name=label, iters=iters, nvsight=nvsight, nvsight_fn_name=label, fw_trace=fw_traces[i]) + print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') + i += 1 diff --git a/thunder/core/trace.py b/thunder/core/trace.py index da1fd27b23..cd9a9bf75b 100644 --- a/thunder/core/trace.py +++ b/thunder/core/trace.py @@ -313,6 +313,11 @@ def set_current_source_location(self, filename: str | None, positions: Positions self._current_source_filename = filename self._current_source_positions = positions + def signature_with_no_ctx(self) -> str: + si = self.siginfo() + signature_str = si.prettyprint(trace=self) + return signature_str + # TODO Account for multi-line signatures # TODO issue "Add type annotations to Python function produced by traces" # Consider extending the signature with type information, in particular the diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 4abcbfbf11..705f277e59 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -257,23 +257,10 @@ def split(): if autotune_type is None: # TODO Restore request for no rematerialization - # TODO (matteochen): remove these logs - c, m, _ = benchmark_trace(fw_extrace, iters=5) - log(f'before remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) - c, m, _ = benchmark_trace(bw_extrace, iters=5) - log(f'before remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) - c, m, _ = benchmark_trace(fw_extrace, iters=5) - log(f'after remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) - c, m, _ = benchmark_trace(bw_extrace, iters=5) - log(f'after remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) # Autotuner has been taken care of remat else: - # TODO (matteochen): remove this - c, m, _ = benchmark_trace(fw_extrace, iters=5) - log(f'after remat fw trace time = {c}, mem = {m}\n{fw_extrace}', level=LogLevel.INFO) - c, m, _ = benchmark_trace(bw_extrace, iters=5) - log(f'after remat bw trace time = {c}, mem = {m}', level=LogLevel.INFO) + pass fw_traces.append(fw_extrace) bw_traces.append(bw_extrace) @@ -391,17 +378,15 @@ def find_torchex_index(): return min(len(cached_executor_list), index) try: + from thunder.benchmarks.utils import AutotunerTorchAutogradBenchmarkUtils + is_tuned = False benchmark_iters: int = 20 + global_optimal_result = AutotunerTorchAutogradBenchmarkUtils() + for i, (ex_type, ex_list) in enumerate(executors_candidates.items()): - # We need to reference some additional tructtures other than the best fw and bw traces as we have to update compile data only after we got the optimal choice - best_cost: float = float('inf') - best_fw_extrace: TraceCtx | None = None - best_bw_extrace: TraceCtx | None = None - best_fw_traces: list[TraceCtx] = [] - best_bw_traces: list[TraceCtx] = [] - best_primal_trace: TraceCtx | None = None - best_executor: Executor | None = None + # We need to reference some additional structures other than the best fw and bw traces as we have to update compile data only after we got the optimal choice + result = AutotunerTorchAutogradBenchmarkUtils() log( f"================================================================================ Before Autotune Tuning: Optimizing {ex_type}", level=LogLevel.INFO) @@ -412,7 +397,9 @@ def find_torchex_index(): if torchex not in to_benchmark: to_benchmark.append(torchex) # Verify that op is present in the trace - op_in_trace: bool = operation_in_trace(trace=computation_trc, op=ex_type) + # op_in_trace: bool = operation_in_trace(trace=computation_trc, op=ex_type) + # TODO (matteochen): currently it is bugged if op is not in trace + op_in_trace: bool = True if (not to_benchmark and op_in_trace) or not op_in_trace: log( @@ -450,45 +437,62 @@ def find_torchex_index(): continue time_fw, mem_fw, _ = benchmark_trace(fw_extrace, iters=benchmark_iters, apply_del_last_used=False) - time_bw, mem_bw, _ = benchmark_trace(bw_extrace, iters=benchmark_iters, apply_del_last_used=False) + time_bw, mem_bw, _ = benchmark_trace(bw_extrace, iters=benchmark_iters, apply_del_last_used=False, fw_trace=fw_extrace) tot_time = time_fw + time_bw tot_mem = mem_fw + mem_bw log( f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time fw = {time_fw} ms - Time bw = {time_bw} ms - Mem fw = {mem_fw / (2**30)} GB - Mem bw = {mem_bw / (2**30)} GB", level=LogLevel.INFO) log( f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time = {tot_time} ms - Mem = {tot_mem / (2**30)} GB", level=LogLevel.INFO) - log(f'Fw trace:\n{fw_extrace}', level=LogLevel.INFO) - log(f'Bw trace:\n{bw_extrace}', level=LogLevel.INFO) benchmark_cost = tot_time if autotune_type == OptimizerType.RUNTIME else tot_mem - if benchmark_cost < best_cost: + if benchmark_cost < result.cost: is_tuned = True - best_cost = benchmark_cost - best_fw_extrace = fw_extrace - best_bw_extrace = bw_extrace - best_fw_traces = fw_traces - best_bw_traces = bw_traces - best_primal_trace = primal_trace - best_executor = e - print(f'Best executor end iteration: {best_executor}') - - c, m , _ = benchmark_trace(best_fw_extrace, iters=benchmark_iters, apply_del_last_used=False) + result = AutotunerTorchAutogradBenchmarkUtils( + benchmark_cost, + fw_extrace, + bw_extrace, + fw_traces, + bw_traces, + primal_trace, + e + ) + if benchmark_cost < global_optimal_result.cost: + is_tuned = True + global_optimal_result = AutotunerTorchAutogradBenchmarkUtils( + benchmark_cost, + fw_extrace, + bw_extrace, + fw_traces, + bw_traces, + primal_trace, + selected_executors=list(compile_data.executors_list) + ) + + log(f'Best executor for {ex_type} iteration: {result.executor}') + + c, m , _ = benchmark_trace(result.fw_trace, iters=benchmark_iters, apply_del_last_used=False, level=LogLevel.DEBUG) log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best fw_extrace (time = {c}, mem = {m}):\n{best_fw_extrace}", level=LogLevel.INFO) - c, m , _ = benchmark_trace(best_bw_extrace, iters=benchmark_iters, apply_del_last_used=False) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best fw_extrace (time = {c}, mem = {m}):\n{result.fw_trace}", level=LogLevel.DEBUG) + c, m , _ = benchmark_trace(result.bw_trace, iters=benchmark_iters, apply_del_last_used=False, fw_trace=result.fw_trace, level=LogLevel.DEBUG) log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best bw_extrace (time = {c}, mem = {m}):\n{best_bw_extrace}", level=LogLevel.INFO) + f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best bw_extrace (time = {c}, mem = {m}):\n{result.bw_trace}", level=LogLevel.DEBUG) # Update the executor list with the winner executor for the current ex_type + # TODO: unify this by using global_optimal_result as it should contain already the optimal placement cached_executor_list = [ex for ex in cached_executor_list if ex not in to_benchmark] - if best_executor is None: + if result.executor is None: log( f"================================================================================ Before Autotune Tuning: Could not find best executor for {ex_type}. Assigning torchex by default", level=LogLevel.INFO) - best_executor = torchex - cached_executor_list.insert(0, best_executor) - placed.add(best_executor) + result.executor = torchex + cached_executor_list.insert(0, result.executor) + placed.add(result.executor) + # Restore all placed but not included in the executor list + for ex_placed in placed: + if ex_placed not in cached_executor_list: + cached_executor_list.append(ex_placed) log( - f"================================================================================ Before Autotune Tuning: Best executor for {ex_type}: {best_executor.name}", level=LogLevel.INFO) + f"================================================================================ Before Autotune Tuning: Best executor for {ex_type}: {result.executor.name}", level=LogLevel.INFO) log( f"================================================================================ Before Autotune Tuning: Benchmark {ex_type}, updated executor list: {cached_executor_list}", level=LogLevel.INFO) @@ -499,16 +503,20 @@ def find_torchex_index(): raise AssertionError( f"No executors have been placed inside the trace. Will autotune the computation_trc ignoring the following executors:\n{executors_candidates}" ) + # Restore - compile_data.executors_list = list(cached_executor_list) + executors_list_to_restore = cached_executor_list if result.cost < global_optimal_result.cost else global_optimal_result.selected_executors + result_to_assign = result if result.cost < global_optimal_result.cost else global_optimal_result + + compile_data.executors_list = list(executors_list_to_restore) log( f"================================================================================ Before Autotune Tuning: autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) if compile_stats is not None: - compile_stats.last_traces.append(best_primal_trace) - compile_stats.last_traces += best_fw_traces - compile_stats.last_backward_traces += best_bw_traces + compile_stats.last_traces.append(result_to_assign.primal_trace) + compile_stats.last_traces += result_to_assign.fw_traces + compile_stats.last_backward_traces += result_to_assign.bw_traces - return best_fw_extrace, best_bw_extrace + return result_to_assign.fw_trace, result_to_assign.bw_trace except Exception as exception: import traceback ex_str = traceback.format_exc() diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index a93ba3d6c4..c19d5bbae7 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -432,9 +432,10 @@ def _te_functional_linear_backward_meta( TensorProxy(like=g, shape=b_shape) if b_shape else None, ) +te_functional_linear_backward_name: str = "te_functional_linear_backward" te_functional_linear_backward = transformer_engine_ex.register_operator( - "te_functional_linear_backward", meta=_te_functional_linear_backward_meta, fn=_te_functional_linear_backward_impl + te_functional_linear_backward_name, meta=_te_functional_linear_backward_meta, fn=_te_functional_linear_backward_impl ) LINEAR_CALLS_COUNTER = 0 @@ -446,13 +447,15 @@ def _te_functional_linear_backward_meta( FP8_RECIPE_KEY = "te_fp8_recipe" +linear_bound_symbol_name_prefix: str = "te_linear" + # Creates a new stateful operator for each invocation of `linear`. def _create_fp8_linear_bound_symbol( a: TensorProxy, w: TensorProxy, b: TensorProxy, is_grad_enabled=False ) -> tuple[torch.Tensor, AnyProxy | None]: linear_fn = partial(TELinear(w.shape[1], w.shape[0]), is_grad_enabled=is_grad_enabled) global LINEAR_CALLS_COUNTER - name = f"te_linear_{LINEAR_CALLS_COUNTER}" + name = f"{linear_bound_symbol_name_prefix}_{LINEAR_CALLS_COUNTER}" desc = "transformer_engine_ex: Optional fp8_recipe for `fp8_autocast` context manager." if (fp8_recipe := get_compile_option(FP8_RECIPE_KEY, desc)) is None: From 79dd4d29325d9c19b70c2ab8dc0ab9da9a348567 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 20 Aug 2024 12:21:06 +0300 Subject: [PATCH 056/171] Unified args building fn --- thunder/backend_optimizer/utils.py | 33 +++++------------------------- 1 file changed, 5 insertions(+), 28 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index eea242df7b..91fcdbe82d 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -370,11 +370,11 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl raise e # TODO (matteochen): use more appropriate mock int and float - def transform_input_tuple(t: tuple, level=0) -> tuple: + def transform_input_sequence(sequence: Sequence, level=0) -> tuple | list: res = [] - for e in t: + for e in sequence: if type(e) is tuple: - res.append(transform_input_tuple(e, level + 1)) + res.append(transform_input_sequence(e, level + 1)) else: if isinstance(e, TensorProxy): res.append(transform_tensor(e)) @@ -392,7 +392,7 @@ def transform_input_tuple(t: tuple, level=0) -> tuple: res.append(None) else: raise AssertionError(f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}') - return tuple(res) + return tuple(res) if level > 0 else res def transform_tensor(arg: TensorProxy) -> torch.Tensor: from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype, is_boolean_dtype @@ -450,9 +450,6 @@ def fix_te_backward_inputs(inputs: list): fixed_inputs_first_index = tuple([tuple(saved_for_bw), inputs[0][1]]) return fixed_inputs_first_index - # Trace real input args - input_args = [] - # Check for correctness if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: raise AssertionError("Missing return statement") @@ -481,33 +478,13 @@ def fix_te_backward_inputs(inputs: list): # Currently it will contain an empty transformer_engineex.Context but might be useful for the future cached_fw_te_ctx_out = fw_output[1][1][0] - # Can we remove this check? - # TODO (matteochen): use more appropriate mock int and float - if isinstance(trace.args, Sequence): - for arg in trace.args: - if isinstance(arg, tuple): - input_args.append(transform_input_tuple(arg)) - elif isinstance(arg, TensorProxy): - e = transform_tensor(arg) - input_args.append(e) - elif isinstance(arg, IntegerProxy): - if arg.python_type is bool: - input_args.append(False if arg.value is None else arg.value) - else: - input_args.append(0 if arg.value is None else arg.value) - elif isinstance(arg, FloatProxy): - input_args.append(0.0 if arg.value is None else arg.value) - else: - raise AssertionError(f"Input arg type not recognized: {type(arg)}") - else: - raise AssertionError("Unexpexcted args type") + input_args: list = transform_input_sequence(trace.args) if te_used and trace_signature.startswith('def backward'): first_tuple = fix_te_backward_inputs(input_args) input_args.pop(0) input_args.insert(0, first_tuple) - trace_tok = set_tracectx(trace) # Obtain the python executable string From e91e0eac19e0afde9a9c3b971465ccd1c89fbabf Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 20 Aug 2024 22:18:30 +0300 Subject: [PATCH 057/171] Beam search for fw bw split operators --- examples/dev/sdpa.py | 28 +- thunder/__init__.py | 17 +- thunder/backend_optimizer/optimizer.py | 16 + thunder/benchmarks/utils.py | 11 +- thunder/core/vjp_utils.py | 324 +++++++++------ thunder/executors/torch_autograd.py | 543 ++++++++----------------- 6 files changed, 424 insertions(+), 515 deletions(-) diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index 7dc0377111..505855fece 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -2,27 +2,33 @@ import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark +torch.set_default_dtype(torch.bfloat16) + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, query, key, value): + def forward(self, query, key, value, q_l, k_l, v_l): a = torch.nn.functional.scaled_dot_product_attention(query, key, value) - return a + b = torch.nn.functional.scaled_dot_product_attention(q_l, k_l, v_l) + return a, b with torch.device('cuda'): model = Model() jmodel_def = thunder.jit(model) - # Order does not matter anymore - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'sdpa', 'cudnn', 'torch', 'python']) + + q = torch.rand(32, 8, 128, 64*1, requires_grad=True) + k = torch.rand(32, 8, 128, 64*1, requires_grad=True) + v = torch.rand(32, 8, 128, 64*1, requires_grad=True) - q = torch.rand(32, 8, 128, 64*16, dtype=torch.float32, requires_grad=True) - k = torch.rand(32, 8, 128, 64*16, dtype=torch.float32, requires_grad=True) - v = torch.rand(32, 8, 128, 64*16, dtype=torch.float32, requires_grad=True) + q_l = torch.rand(32, 8, 128, 64*1, requires_grad=True) + k_l = torch.rand(32, 8, 128, 64*1, requires_grad=True) + v_l = torch.rand(32, 8, 128, 64*1, requires_grad=True) - print('deviation def:', (jmodel_def(q, k, v) - model(q, k, v)).abs().max().item()) - print('deviation auto:', (jmodel_auto(q, k, v) - model(q, k, v)).abs().max().item()) + jmodel_def(q, k, v, q_l, k_l, v_l) + jmodel_auto(q, k, v, q_l, k_l, v_l) iters = 100 fw_traces = [ @@ -46,7 +52,3 @@ def forward(self, query, key, value): print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - - - - diff --git a/thunder/__init__.py b/thunder/__init__.py index d5f367ac71..4e6d6dcc09 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -273,7 +273,7 @@ def jit( disable_torch_autograd: bool = False, # TODO Revisit this UX for RC1 transforms: list[Transform] | None = None, record_history: bool = False, - autotune_type: Any | None = None, + # autotune_type: Any | None = None, **compile_options, # TODO RC1 Make this explicit -- dict of options ) -> Callable: """Just-in-time compile a callable (function or model). @@ -322,13 +322,12 @@ def jit( if transforms is None: transforms = [] - if autotune_type is not None: - if autotune_type == 'runtime': - autotune_type = OptimizerType.RUNTIME - elif autotune_type == 'memory': - autotune_type = OptimizerType.MEMORY - else: - raise AssertionError(f'Not supported optimization: {autotune_type}') + required_autotune = compile_options.get('autotune_type', None) + if required_autotune is not None: + if required_autotune not in ['runtime', 'memory']: + raise AssertionError(f'Not supported optimization: {required_autotune}') + + compile_options |= {"autotune_type": OptimizerType.RUNTIME if required_autotune == 'runtime' else OptimizerType.MEMORY} # Default the executors list to all_executors if no options are given # Otherwise the user restricted choice will be used @@ -667,7 +666,7 @@ def get_computation_and_inputs(*args, **kwargs): # transform_for_execution and various sorting of symbols, # applying transform_for_execution after this would be # breaking the order of operations - computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, autotune_type, *inps) + computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps) # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 92f644f784..648ffd1bad 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -10,6 +10,22 @@ from typing import Hashable from thunder.backend_optimizer.utils import benchmark_trace +# Defining a wrapper fn as the imports will crash in the global scope +def get_fw_bw_split_backends_options(bsym: BoundSymbol) -> list: + from thunder.executors.sdpaex import sdpa_ex + from thunder.executors.cudnnex import cudnn_ex + from thunder.executors.fa3ex import fa3_ex + from thunder.executors.transformer_engineex import transformer_engine_ex + #Current configuration + options: dict[str, list] = { + # TODO: filter out TE only if requested + 'linear': [transformer_engine_ex], + 'scaled_dot_product_attention': [sdpa_ex, cudnn_ex, fa3_ex], + } + + return options.get(bsym.sym.name, []) + + class BenchmarkResult: def __init__( self, diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 85d8a098a0..00c7ac844c 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -1,10 +1,19 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence import torch from thunder.backend_optimizer.utils import benchmark_trace from thunder.core.trace import TraceCtx warm_up_iters = 50 +class SplitFwBwBenchmarkUtils(): + def __init__( + self, *, cost: float = float("inf"), fw_fn: Callable | None = None, bw_fn: Callable | None = None, executor = None + ) -> None: + self.cost: float = cost + self.fw_fn: Callable | None = fw_fn + self.bw_fn: Callable | None = bw_fn + self.executor = executor + class AutotunerTorchAutogradBenchmarkUtils(): def __init__( self, diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 46fd75bc73..c946c334bf 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -5,12 +5,14 @@ from itertools import chain from thunder.core import prims, utils +from thunder.core.compile_data import get_compile_data from thunder.core.prims import PrimIDs from thunder.core.proxies import Proxy, variableify, TensorProxy from thunder.core.pytree import tree_flatten, tree_map from thunder.core.symbol import BoundSymbol from thunder.core.trace import from_trace, TraceCtx from thunder.core.transform_common import dce +from thunder.extend import Executor _cache = {} @@ -49,135 +51,209 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable from thunder.common import _make_cache_key from thunder.core.transforms import _get_gradfn_and_executor, eval_trace - joint_forward_backward, executor = _get_gradfn_and_executor(bsym) - utils.check( - joint_forward_backward is not None, - lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", - ) - key = (bsym.sym, executor, subkey := _make_cache_key(bsym.args, bsym.kwargs)) - cached_result = _cache.get(key, None) if subkey is not None else None - if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): - return cached_result - - joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) - consumers = utils.consumers(joint_trace) - - def find_backward_input(forward_output): - output_consumers = consumers.get(forward_output, None) - if output_consumers is None or not output_consumers: - return None - get_grad_bsym = next( - filter(lambda bsym: bsym.sym.id == PrimIDs.GET_GRAD, output_consumers), - None, + def _make_aug_forward_and_backward(return_traces = False) -> tuple[Callable, Callable] | tuple[Callable, Callable, TraceCtx, TraceCtx]: + joint_forward_backward, executor = _get_gradfn_and_executor(bsym) + utils.check( + joint_forward_backward is not None, + lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", ) - return get_grad_bsym.output if get_grad_bsym is not None else None - - def find_backward_output(forward_input): - forward_input_consumers = consumers.get(forward_input, None) - if forward_input_consumers is None or not forward_input_consumers: - return None - put_grad_bsym = next( - filter(lambda bsym: bsym.sym.id == PrimIDs.PUT_GRAD, forward_input_consumers), - None, + key = (bsym.sym, executor, subkey := _make_cache_key(bsym.args, bsym.kwargs)) + cached_result = _cache.get(key, None) if subkey is not None else None + if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): + return cached_result + + joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) + consumers = utils.consumers(joint_trace) + + def find_backward_input(forward_output): + output_consumers = consumers.get(forward_output, None) + if output_consumers is None or not output_consumers: + return None + get_grad_bsym = next( + filter(lambda bsym: bsym.sym.id == PrimIDs.GET_GRAD, output_consumers), + None, + ) + return get_grad_bsym.output if get_grad_bsym is not None else None + + def find_backward_output(forward_input): + forward_input_consumers = consumers.get(forward_input, None) + if forward_input_consumers is None or not forward_input_consumers: + return None + put_grad_bsym = next( + filter(lambda bsym: bsym.sym.id == PrimIDs.PUT_GRAD, forward_input_consumers), + None, + ) + return put_grad_bsym.args[1] if put_grad_bsym is not None else None + + bw_inputs = tree_map(find_backward_input, utils.sequencify(joint_trace.output)) + bw_outputs_args = tree_map(find_backward_output, joint_trace.args) + bw_outputs_kwargs = tree_map(find_backward_output, joint_trace.kwargs) + meta_parameters = inspect.signature(bsym.sym.meta).parameters + meta_parameters = { + name: param + for name, param in meta_parameters.items() + if param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY) + } + bw_outputs = {name: bw_output for name, bw_output in utils.safe_zip(meta_parameters, bw_outputs_args)} + bw_outputs = bw_outputs | bw_outputs_kwargs + flat_bw_outputs, _ = tree_flatten(bw_outputs) + + backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0]) + skip = ( + prims.PrimIDs.UNPACK_EMPTY_DICT, + prims.PrimIDs.UNPACK_KEY, + prims.PrimIDs.UNPACK_SEQUENCE, + prims.PrimIDs.UNPACK_TRIVIAL, + prims.PrimIDs.GET_GRAD, ) - return put_grad_bsym.args[1] if put_grad_bsym is not None else None - - bw_inputs = tree_map(find_backward_input, utils.sequencify(joint_trace.output)) - bw_outputs_args = tree_map(find_backward_output, joint_trace.args) - bw_outputs_kwargs = tree_map(find_backward_output, joint_trace.kwargs) - meta_parameters = inspect.signature(bsym.sym.meta).parameters - meta_parameters = { - name: param - for name, param in meta_parameters.items() - if param.kind in (Parameter.POSITIONAL_OR_KEYWORD, Parameter.POSITIONAL_ONLY) - } - bw_outputs = {name: bw_output for name, bw_output in utils.safe_zip(meta_parameters, bw_outputs_args)} - bw_outputs = bw_outputs | bw_outputs_kwargs - flat_bw_outputs, _ = tree_flatten(bw_outputs) - - backward_bsyms = utils.find_producer_symbols(joint_trace, flat_bw_outputs, tree_flatten(bw_inputs)[0]) - skip = ( - prims.PrimIDs.UNPACK_EMPTY_DICT, - prims.PrimIDs.UNPACK_KEY, - prims.PrimIDs.UNPACK_SEQUENCE, - prims.PrimIDs.UNPACK_TRIVIAL, - prims.PrimIDs.GET_GRAD, - ) - backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in skip] - backward_bsyms.append(prims.python_return.bind(bw_outputs, output=())) - - forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] - forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] - forward_bsyms = utils.find_producer_symbols(joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies) - backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in forward_bsyms] - - # Find required info from forward trace for backward trace - backward_producers = utils.producers(backward_bsyms) - saved_for_backward = [] - for backward_bsym in backward_bsyms: - for arg in backward_bsym.flat_args: - if not isinstance(arg, Proxy): - continue - if arg not in backward_producers and variableify(arg) not in map(variableify, tree_flatten(bw_inputs)[0]): - saved_for_backward.append(arg) - - saved_for_backward = list({variableify(arg): arg for arg in saved_for_backward}.values()) - - # Augment forward trace to include saved_for_backward as output - augmented_forward_trace = from_trace(joint_trace) - augmented_forward_trace.bound_symbols = [ - b for b in joint_trace.bound_symbols if b.sym.id not in (PrimIDs.PUT_GRAD, PrimIDs.GET_GRAD) - ] - return_bsym = augmented_forward_trace.bound_symbols[-1] - assert return_bsym.sym.id == PrimIDs.RETURN - augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( - (joint_trace.output, saved_for_backward), output=() - ) - # Remove put/get grad and backward symbols from augmented forward trace - augmented_forward_trace = dce(augmented_forward_trace) - - # Check if any of the bound symbols in the backward trace are also in the - # augmented forward trace - # If so, remove them from the backward trace - same_bsyms = set(augmented_forward_trace.bound_symbols) & set(backward_bsyms) - if same_bsyms: - backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in same_bsyms] - additional_saved = [o for bsym in same_bsyms for o in bsym.flat_proxy_outs] - saved_for_backward += list({variableify(arg): arg for arg in additional_saved}.values()) + backward_bsyms = [bsym for bsym in backward_bsyms if bsym.sym.id not in skip] + backward_bsyms.append(prims.python_return.bind(bw_outputs, output=())) + + forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] + forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] + forward_bsyms = utils.find_producer_symbols(joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies) + backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in forward_bsyms] + + # Find required info from forward trace for backward trace + backward_producers = utils.producers(backward_bsyms) + saved_for_backward = [] + for backward_bsym in backward_bsyms: + for arg in backward_bsym.flat_args: + if not isinstance(arg, Proxy): + continue + if arg not in backward_producers and variableify(arg) not in map(variableify, tree_flatten(bw_inputs)[0]): + saved_for_backward.append(arg) + + saved_for_backward = list({variableify(arg): arg for arg in saved_for_backward}.values()) + + # Augment forward trace to include saved_for_backward as output + augmented_forward_trace = from_trace(joint_trace) + augmented_forward_trace.bound_symbols = [ + b for b in joint_trace.bound_symbols if b.sym.id not in (PrimIDs.PUT_GRAD, PrimIDs.GET_GRAD) + ] + return_bsym = augmented_forward_trace.bound_symbols[-1] + assert return_bsym.sym.id == PrimIDs.RETURN augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( (joint_trace.output, saved_for_backward), output=() ) - - backward_params = [ - Parameter(getattr(x, "name", f"arg{i}"), Parameter.POSITIONAL_OR_KEYWORD) - for i, x in enumerate(chain(saved_for_backward, bw_inputs)) - ] - backward_signature = Signature(backward_params) - - def backward_fn(): - pass - - backward_fn.__signature__ = backward_signature - backward_fn.__name__ = bsym.sym.name + "_backward" - - # Finally, build the backward trace - backward_trace = TraceCtx(backward_fn) - backward_trace.args = (*saved_for_backward, *bw_inputs) - backward_trace.kwargs = {} - backward_trace.bound_symbols = backward_bsyms - - # Creating new functions instead of using partial to avoid limitations in - # codeutils.get_siginfo - # https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/codeutils.py#L349-L353 - def fw_fn(*args, **kwargs): - return eval_trace(augmented_forward_trace, *args, **kwargs) - - def bw_fn(*args, **kwargs): - return eval_trace(backward_trace, *args, **kwargs) - - _cache[key] = fw_fn, bw_fn - - return fw_fn, bw_fn + # Remove put/get grad and backward symbols from augmented forward trace + augmented_forward_trace = dce(augmented_forward_trace) + + # Check if any of the bound symbols in the backward trace are also in the + # augmented forward trace + # If so, remove them from the backward trace + same_bsyms = set(augmented_forward_trace.bound_symbols) & set(backward_bsyms) + if same_bsyms: + backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in same_bsyms] + additional_saved = [o for bsym in same_bsyms for o in bsym.flat_proxy_outs] + saved_for_backward += list({variableify(arg): arg for arg in additional_saved}.values()) + augmented_forward_trace.bound_symbols[-1] = prims.python_return.bind( + (joint_trace.output, saved_for_backward), output=() + ) + + backward_params = [ + Parameter(getattr(x, "name", f"arg{i}"), Parameter.POSITIONAL_OR_KEYWORD) + for i, x in enumerate(chain(saved_for_backward, bw_inputs)) + ] + backward_signature = Signature(backward_params) + + def backward_fn(): + pass + + backward_fn.__signature__ = backward_signature + backward_fn.__name__ = bsym.sym.name + "_backward" + + # Finally, build the backward trace + backward_trace = TraceCtx(backward_fn) + backward_trace.args = (*saved_for_backward, *bw_inputs) + backward_trace.kwargs = {} + backward_trace.bound_symbols = backward_bsyms + + # Creating new functions instead of using partial to avoid limitations in + # codeutils.get_siginfo + # https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/core/codeutils.py#L349-L353 + def fw_fn(*args, **kwargs): + return eval_trace(augmented_forward_trace, *args, **kwargs) + + def bw_fn(*args, **kwargs): + return eval_trace(backward_trace, *args, **kwargs) + + if not return_traces: + return fw_fn, bw_fn + return fw_fn, bw_fn, augmented_forward_trace, backward_trace + + cd = get_compile_data() + assert cd + # No autotuning + if not cd.compile_options.get('autotune_type', None): + return _make_aug_forward_and_backward() + + # This search will be performed on the requested executors list + is_backend_available: bool = _get_gradfn_and_executor(bsym)[1] is not None + if not is_backend_available: + key = (bsym.sym, None, subkey := _make_cache_key(bsym.args, bsym.kwargs)) + # Cached will be checked in the inner fn if not miss + fw_fn, bw_fn = _make_aug_forward_and_backward() + _cache[key] = fw_fn, bw_fn + return fw_fn, bw_fn + # We have a backend + else: + from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options + from thunder.backend_optimizer.utils import benchmark_trace + + # In order define this unique trace region we need a unique id + bsym_id = id(bsym) + key = (bsym.sym, Executor(f'{bsym_id}-autotuned'), subkey := _make_cache_key(bsym.args, bsym.kwargs)) + # We do check the cache here as the key in the inner fn does not know about this special id + cached_result = _cache.get(key, None) if subkey is not None else None + # NOTE: cache is always enabled here + if cached_result is not None: + return cached_result + + # Get the possible backends for the current bsym + backends = get_fw_bw_split_backends_options(bsym) + assert backends + + cached_executors_list = list(cd.executors_list) + # Retrieve all the executors which are requested to be used + requested_executors_list_for_bsym = [ex for ex in cached_executors_list if ex in backends] + # from thunder.executors.torchex import ex as torchex + # if torchex not in executors_list: + # executors_list.append(torchex) + from thunder.benchmarks.utils import SplitFwBwBenchmarkUtils + from thunder.backend_optimizer.optimizer import OptimizerType + best = SplitFwBwBenchmarkUtils() + + # Restrict the search space + backends = list(requested_executors_list_for_bsym) + print(f'Search space for {bsym.sym.name}: {backends}') + for b in backends: + print('Benchmarking backend', b.name) + # Let downstream fn to pick up this + requested_executors_list_for_bsym.remove(b) + requested_executors_list_for_bsym.insert(0, b) + cd.executors_list = requested_executors_list_for_bsym + fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(True) + fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=20, apply_del_last_used=False) + bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=20, apply_del_last_used=False) + cost = fw_time + bw_time if cd.compile_options['autotune_type'] == OptimizerType.RUNTIME else fw_mem + bw_mem + print('cost', cost) + if cost < best.cost: + best = SplitFwBwBenchmarkUtils(cost = cost, fw_fn = fw_fn, bw_fn = bw_fn, executor = b) + + assert best.cost != float('inf') + from thunder.backend_optimizer.optimizer import log + log(f'Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}') + + # Update the compile options + cached_executors_list = [ex for ex in cached_executors_list if ex not in backends] + cached_executors_list.insert(0, best.executor) + + # Restore executor list for downstream optimizations + cd.executors_list = cached_executors_list + + _cache[key] = best.fw_fn, best.bw_fn + return best.fw_fn, best.bw_fn def get_saved_for_backward_tensors(trace: TraceCtx) -> tuple[TensorProxy]: diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 705f277e59..2589afb057 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -138,7 +138,7 @@ def update_bw_from_forward_optimization(*, fw: TraceCtx, bw:TraceCtx) -> TraceCt return bw -def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, autotune_type, /, *flat_args): +def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args): from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace from thunder.distributed.transforms import FSDPCommBucketing @@ -147,393 +147,200 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.visualizer.visualizer_helper import Visualizer from thunder.backend_optimizer.optimizer import log, LogLevel, TraceType, BackendOptimizer, OptimizerType, benchmark_trace - def split(): - utils.check(compile_data is not None, lambda: "`compile_data` is required") - # NOTE: This function is rather slow, so it's intended to be used - # behind a cache. - tensor_cls = (torch.Tensor, TensorProxy) - requires_grad_mask = tuple(isinstance(arg, tensor_cls) and arg.requires_grad for arg in flat_args) - # If none of the inputs require gradients, raise an error - if not any(requires_grad_mask): - raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") - - primal_trace = computation_trc - primal_trace = sort_data_parallel_syncs(primal_trace) - - # Handled by the caller if autotune is not None - if compile_stats is not None and autotune_type is None: - compile_stats.last_traces.append(primal_trace) - - # torch.autograd.Function doesn't support non-flat outputs, the - # grads wouldn't be propagated and backward receives None for each - # non-flat non-tensor output. The output must also be a flat tuple, - # not any other container type. So we need to flatten the outputs of - # the forward trace and inputs of the backward trace. - fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) - - fw_traces = [fw_trace] - bw_traces = [bw_trace] - - from thunder.distributed import FSDPType - - # only enable rematerialize_params_in_backward when using FSDP ZeRO3 - _rematerialize_params_in_backward = ( - getattr(compile_data.fn, "use_fsdp", False) and getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3 - ) - if _rematerialize_params_in_backward: - fw_trace, bw_trace = rematerialize_all_gather(fw_trace, bw_trace) - - # Update the backward trace to only compute gradients for the - # inputs that require gradients - assert bw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN - filtered_grads = tuple( - (arg_grad if requires_grad else None) - for arg_grad, requires_grad in utils.safe_zip(bw_trace.bound_symbols[-1].args[0], requires_grad_mask) - ) + utils.check(compile_data is not None, lambda: "`compile_data` is required") + # NOTE: This function is rather slow, so it's intended to be used + # behind a cache. + tensor_cls = (torch.Tensor, TensorProxy) + requires_grad_mask = tuple(isinstance(arg, tensor_cls) and arg.requires_grad for arg in flat_args) + # If none of the inputs require gradients, raise an error + if not any(requires_grad_mask): + raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") - # autograd.Function.backward expects a flat tuple of gradients - bw_trace.bound_symbols[-1] = replace(bw_trace.bound_symbols[-1], args=(filtered_grads,)) - - _fsdp_comm_bucketing: FSDPCommBucketing | None = None - if getattr(compile_data.fn, "use_fsdp", False): - _fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc) - fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace) - - do_apply_bucketing_bw_trace: bool = getattr(compile_data.fn, "use_fsdp", False) - - # Now we can run the optimization passes on the forward trace - visualizer = Visualizer(produce_hidden=False) - backend_optimizer_ctx: BackendOptimizer | None = ( - None - if autotune_type is None - else BackendOptimizer( - priority_executors=compile_data.executors_list, - apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, - produce_log=True, - visualizer=visualizer, - optimizer_type=autotune_type, - compile_data=compile_data - ) + autotune_type = compile_data.compile_options.get('autotune_type', None) + + primal_trace = computation_trc + primal_trace = sort_data_parallel_syncs(primal_trace) + + # Handled by the caller if autotune is not None + if compile_stats is not None: + compile_stats.last_traces.append(primal_trace) + + # torch.autograd.Function doesn't support non-flat outputs, the + # grads wouldn't be propagated and backward receives None for each + # non-flat non-tensor output. The output must also be a flat tuple, + # not any other container type. So we need to flatten the outputs of + # the forward trace and inputs of the backward trace. + fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) + + fw_traces = [fw_trace] + bw_traces = [bw_trace] + + from thunder.distributed import FSDPType + + # only enable rematerialize_params_in_backward when using FSDP ZeRO3 + _rematerialize_params_in_backward = ( + getattr(compile_data.fn, "use_fsdp", False) and getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3 + ) + if _rematerialize_params_in_backward: + fw_trace, bw_trace = rematerialize_all_gather(fw_trace, bw_trace) + + # Update the backward trace to only compute gradients for the + # inputs that require gradients + assert bw_trace.bound_symbols[-1].sym.id == PrimIDs.RETURN + filtered_grads = tuple( + (arg_grad if requires_grad else None) + for arg_grad, requires_grad in utils.safe_zip(bw_trace.bound_symbols[-1].args[0], requires_grad_mask) + ) + + # autograd.Function.backward expects a flat tuple of gradients + bw_trace.bound_symbols[-1] = replace(bw_trace.bound_symbols[-1], args=(filtered_grads,)) + + _fsdp_comm_bucketing: FSDPCommBucketing | None = None + if getattr(compile_data.fn, "use_fsdp", False): + _fsdp_comm_bucketing = FSDPCommBucketing(compile_data, computation_trc) + fw_trace = _fsdp_comm_bucketing.apply_bucketing_to_forward_trace(fw_trace) + + do_apply_bucketing_bw_trace: bool = getattr(compile_data.fn, "use_fsdp", False) + + # Now we can run the optimization passes on the forward trace + visualizer = Visualizer(produce_hidden=False) + backend_optimizer_ctx: BackendOptimizer | None = ( + None + if autotune_type is None + else BackendOptimizer( + priority_executors=compile_data.executors_list, + apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, + produce_log=True, + visualizer=visualizer, + optimizer_type=autotune_type, + compile_data=compile_data ) + ) - visualizer.set_fw_initial_trace(fw_trace) - # Get optimzied fw trace - fw_extrace = ( - transform_for_execution(fw_trace, executors_list=compile_data.executors_list) - if autotune_type is None - else autotune_transform_for_execution( - optimizer_context=backend_optimizer_ctx, trace=fw_trace, trace_type=TraceType.FW - ) + visualizer.set_fw_initial_trace(fw_trace) + # Get optimzied fw trace + fw_extrace = ( + transform_for_execution(fw_trace, executors_list=compile_data.executors_list) + if autotune_type is None + else autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=fw_trace, trace_type=TraceType.FW ) + ) - # If in default mode, otherwise the best fw will be returned only at the end - if autotune_type is None: - # Here fw_extrace is not None + # If in default mode, otherwise the best fw will be returned only at the end + if autotune_type is None: + # Here fw_extrace is not None - fw_traces.append(fw_extrace) - visualizer.set_fw_optimized_trace(fw_extrace) + fw_traces.append(fw_extrace) + visualizer.set_fw_optimized_trace(fw_extrace) - # If autotuning is activated, it will take care of the followinf 2 calls - bw_trace = update_bw_from_forward_optimization(fw=fw_extrace, bw=bw_trace) - if do_apply_bucketing_bw_trace: - bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) + # If autotuning is activated, it will take care of the following 2 calls + bw_trace = update_bw_from_forward_optimization(fw=fw_extrace, bw=bw_trace) + if do_apply_bucketing_bw_trace: + bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) - # Now we can run the optimization passes on the backward trace + # Now we can run the optimization passes on the backward trace + visualizer.set_bw_initial_trace(bw_trace) + if autotune_type is not None: + fw_extrace, bw_extrace = autotune_transform_for_execution( + optimizer_context=backend_optimizer_ctx, trace=bw_trace, trace_type=TraceType.BW + ) + fw_traces.append(fw_extrace) + visualizer.set_bw_optimized_trace(fw_extrace) + else: + bw_extrace = transform_for_execution( + bw_trace, + executors_list=compile_data.executors_list, + ) + bw_traces.append(bw_extrace) + visualizer.set_bw_optimized_trace(bw_extrace) + + if autotune_type is None: # TODO Restore request for no rematerialization - visualizer.set_bw_initial_trace(bw_trace) - if autotune_type is not None: - fw_extrace, bw_extrace = autotune_transform_for_execution( - optimizer_context=backend_optimizer_ctx, trace=bw_trace, trace_type=TraceType.BW + fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) + # Autotuner has been taken care of remat + else: + pass + fw_traces.append(fw_extrace) + bw_traces.append(bw_extrace) + + # We need to sort the waits in forward and backward trace to overlap + # computation with communication + # For performance we need the wait_prim_impl nodes in the execution trace to be as far from the + # communication ops as possible. But it causes the all_gather_prim_impl nodes gathered at the start of + # backward trace and increases the peak allocated memory + use_fsdp: bool = getattr(compile_data.fn, "use_fsdp", False) + if use_fsdp: + assert hasattr(compile_data.fn, "sharding_strategy") + if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3: + from thunder.distributed import FSDPBucketingStrategy + from thunder.distributed.utils import limit_in_flight_allgathers + + fw_extrace = sort_communication_ops(fw_extrace) + fw_extrace = limit_in_flight_allgathers( + fw_extrace, + 3, + compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, ) - fw_traces.append(fw_extrace) - visualizer.set_bw_optimized_trace(fw_extrace) - else: - bw_extrace = transform_for_execution( - bw_trace, - executors_list=compile_data.executors_list, + bw_extrace = sort_communication_ops(bw_extrace) + bw_extrace = limit_in_flight_allgathers( + bw_extrace, + 3, + compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, + ) + if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO2: + from thunder.distributed import FSDPBucketingStrategy + from thunder.distributed.utils import limit_in_flight_allgathers + from sys import maxsize as INT_MAX + + # sort the allgather+wait as consumer order just before consumer + fw_extrace = sort_communication_ops(fw_extrace) + # unlimited number of allgathers, i.e. allgathers are listed at the beginning of the trace in consumer order and wait stays just before wait + fw_extrace = limit_in_flight_allgathers( + fw_extrace, + INT_MAX, + compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, ) - bw_traces.append(bw_extrace) - visualizer.set_bw_optimized_trace(bw_extrace) - - if autotune_type is None: - # TODO Restore request for no rematerialization - fw_extrace, bw_extrace = rematerialize_forward_and_backward(fw_extrace, bw_extrace) - # Autotuner has been taken care of remat - else: - pass - fw_traces.append(fw_extrace) - bw_traces.append(bw_extrace) - - # We need to sort the waits in forward and backward trace to overlap - # computation with communication - # For performance we need the wait_prim_impl nodes in the execution trace to be as far from the - # communication ops as possible. But it causes the all_gather_prim_impl nodes gathered at the start of - # backward trace and increases the peak allocated memory - use_fsdp: bool = getattr(compile_data.fn, "use_fsdp", False) - if use_fsdp: - assert hasattr(compile_data.fn, "sharding_strategy") - if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO3: - from thunder.distributed import FSDPBucketingStrategy - from thunder.distributed.utils import limit_in_flight_allgathers - - fw_extrace = sort_communication_ops(fw_extrace) - fw_extrace = limit_in_flight_allgathers( - fw_extrace, - 3, - compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, - ) - bw_extrace = sort_communication_ops(bw_extrace) - bw_extrace = limit_in_flight_allgathers( - bw_extrace, - 3, - compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, - ) - if getattr(compile_data.fn, "sharding_strategy") == FSDPType.ZERO2: - from thunder.distributed import FSDPBucketingStrategy - from thunder.distributed.utils import limit_in_flight_allgathers - from sys import maxsize as INT_MAX - - # sort the allgather+wait as consumer order just before consumer - fw_extrace = sort_communication_ops(fw_extrace) - # unlimited number of allgathers, i.e. allgathers are listed at the beginning of the trace in consumer order and wait stays just before wait - fw_extrace = limit_in_flight_allgathers( - fw_extrace, - INT_MAX, - compile_data.fn.bucketing_strategy != FSDPBucketingStrategy.NONE, - ) - bw_extrace = sort_waits(bw_extrace) - use_ddp: bool = getattr(compile_data.fn, "use_ddp", False) - if use_ddp: bw_extrace = sort_waits(bw_extrace) - if (not use_ddp) and (not use_fsdp): - from thunder.distributed.utils import maybe_sort_waits + use_ddp: bool = getattr(compile_data.fn, "use_ddp", False) + if use_ddp: + bw_extrace = sort_waits(bw_extrace) + if (not use_ddp) and (not use_fsdp): + from thunder.distributed.utils import maybe_sort_waits - _, fw_extrace = maybe_sort_waits(fw_extrace) - _, bw_extrace = maybe_sort_waits(bw_extrace) + _, fw_extrace = maybe_sort_waits(fw_extrace) + _, bw_extrace = maybe_sort_waits(bw_extrace) - # Importing here to avoid cyclical dependencies in future. - from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex + # Importing here to avoid cyclical dependencies in future. + from thunder.executors.transformer_engineex import _transformer_engine_bwd_fp8_meta_sync, transformer_engine_ex - if transformer_engine_ex in compile_data.executors_list: - # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`. - _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace) + if transformer_engine_ex in compile_data.executors_list: + # NOTE: `_transformer_engine_bwd_fp8_meta_sync` may mutate `fw_extrace` or `bw_extrace`. + _transformer_engine_bwd_fp8_meta_sync(fw_extrace, bw_extrace) - fw_extrace = del_last_used(fw_extrace) - fw_traces.append(fw_extrace) - - bw_extrace = del_last_used(bw_extrace, clear_mutable_collections=True) - bw_traces.append(bw_extrace) - - bw_trace = rename_bwd_trace_outputs(bw_extrace, fw_extrace) + fw_extrace = del_last_used(fw_extrace) + fw_traces.append(fw_extrace) - # This is moved to the caller if autotune is enabled - if compile_stats is not None and autotune_type is None: - compile_stats.last_traces += fw_traces - compile_stats.last_backward_traces += bw_traces + bw_extrace = del_last_used(bw_extrace, clear_mutable_collections=True) + bw_traces.append(bw_extrace) - # Enable wrapping with `te.fp8_autocast`. - fw_extrace._include_te_fp8_autocast = True - # We only want the forward function to be called with `te.fp8_autocast` manager. - bw_extrace._include_te_fp8_autocast = False + bw_trace = rename_bwd_trace_outputs(bw_extrace, fw_extrace) - # Let's include the last traces also after all the passes - visualizer.set_fw_final_trace(fw_extrace) - visualizer.set_bw_final_trace(bw_extrace) + # This is moved to the caller if autotune is enabled + if compile_stats is not None: + compile_stats.last_traces += fw_traces + compile_stats.last_backward_traces += bw_traces - # TODO: implement new visualizer - # visualizer.produce() + # Enable wrapping with `te.fp8_autocast`. + fw_extrace._include_te_fp8_autocast = True + # We only want the forward function to be called with `te.fp8_autocast` manager. + bw_extrace._include_te_fp8_autocast = False - if autotune_type is None: - return fw_extrace, bw_extrace - else: - return primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces - - # Defined executors that are matched inside the fw and bw split, hence outside the autotuner scope - # TODO (matteochen): integrate Transofrmer Engine - from thunder.executors.sdpaex import sdpa_ex - from thunder.executors.cudnnex import cudnn_ex - from thunder.executors.fa3ex import fa3_ex - from thunder.executors.transformer_engineex import transformer_engine_ex - from thunder.executors.torchex import ex as torchex - - executors_candidates: dict[str, list] = { - 'linear': [transformer_engine_ex.name], - 'scaled_dot_product_attention': [sdpa_ex.name, cudnn_ex.name, fa3_ex.name], - } + # Let's include the last traces also after all the passes + visualizer.set_fw_final_trace(fw_extrace) + visualizer.set_bw_final_trace(bw_extrace) - # TODO (matteochen): use BackendOptimizer tracing + # TODO: implement new visualizer + # visualizer.produce() - # If autotuner is enabled, we compare different impl of executors which are assigned inside the call 'forward_and_backward_from_trace' - # as the autotuner will receive already split fw and bw traces - if autotune_type is not None: - cached_executor_list = list(compile_data.executors_list) - cached_executor_list_copy = list(compile_data.executors_list) - placed = set() - - def find_torchex_index(): - index = 0 - for i, e in enumerate(cached_executor_list): - if e in placed: - index = i+1 - return min(len(cached_executor_list), index) - - try: - from thunder.benchmarks.utils import AutotunerTorchAutogradBenchmarkUtils - - is_tuned = False - benchmark_iters: int = 20 - global_optimal_result = AutotunerTorchAutogradBenchmarkUtils() - - for i, (ex_type, ex_list) in enumerate(executors_candidates.items()): - # We need to reference some additional structures other than the best fw and bw traces as we have to update compile data only after we got the optimal choice - result = AutotunerTorchAutogradBenchmarkUtils() - log( - f"================================================================================ Before Autotune Tuning: Optimizing {ex_type}", - level=LogLevel.INFO) - # Filter out the executors based on the executors list, maybe not all the options have to be used - # Torch executor (default kernel) will be given a chance always - to_benchmark: list[Executor] = [ex for ex in cached_executor_list if ex.name in ex_list] - # Add torchexecutor if not present - if torchex not in to_benchmark: - to_benchmark.append(torchex) - # Verify that op is present in the trace - # op_in_trace: bool = operation_in_trace(trace=computation_trc, op=ex_type) - # TODO (matteochen): currently it is bugged if op is not in trace - op_in_trace: bool = True - - if (not to_benchmark and op_in_trace) or not op_in_trace: - log( - f"================================================================================ Before Autotune Tuning: Skipping optimization for {ex_type} as not requested or not present in computation_trc.", - level=LogLevel.INFO, - ) - continue - - log( - f"================================================================================ Before Autotune Tuning: Executors to bench for {ex_type}: {to_benchmark}", - level=LogLevel.INFO, - ) - for e in to_benchmark: - # Create the executor list putting the executor under analysis at the head of queue - # 1. Add all executors except the ones under benchmark - compile_data.executors_list = [ex for ex in cached_executor_list if ex not in to_benchmark] - # 2. Make the current one with most priority to be picked up by - if e in compile_data.executors_list: - compile_data.executors_list.insert(0, compile_data.executors_list.pop(compile_data.executors_list.index(e))) - else: - compile_data.executors_list.insert(0, e) - # TODO: write why we have to place torchex as near as possible to the start of the list - if torchex not in compile_data.executors_list: - torchex_index = max(1, find_torchex_index()) - compile_data.executors_list.insert(torchex_index, torchex) - log( - f"================================================================================ Before Autotune Tuning: Testing compile data executors: {compile_data.executors_list}", level=LogLevel.INFO) - try: - primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() - except Exception as exception: - import traceback - ex_str = traceback.format_exc() - log( - f"================================================================================ Before Autotune Tuning: Failed to place {e.name}: {exception}\n{ex_str}") - continue - - time_fw, mem_fw, _ = benchmark_trace(fw_extrace, iters=benchmark_iters, apply_del_last_used=False) - time_bw, mem_bw, _ = benchmark_trace(bw_extrace, iters=benchmark_iters, apply_del_last_used=False, fw_trace=fw_extrace) - tot_time = time_fw + time_bw - tot_mem = mem_fw + mem_bw - log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time fw = {time_fw} ms - Time bw = {time_bw} ms - Mem fw = {mem_fw / (2**30)} GB - Mem bw = {mem_bw / (2**30)} GB", level=LogLevel.INFO) - log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} options: {e.name}. Time = {tot_time} ms - Mem = {tot_mem / (2**30)} GB", level=LogLevel.INFO) - - benchmark_cost = tot_time if autotune_type == OptimizerType.RUNTIME else tot_mem - if benchmark_cost < result.cost: - is_tuned = True - result = AutotunerTorchAutogradBenchmarkUtils( - benchmark_cost, - fw_extrace, - bw_extrace, - fw_traces, - bw_traces, - primal_trace, - e - ) - if benchmark_cost < global_optimal_result.cost: - is_tuned = True - global_optimal_result = AutotunerTorchAutogradBenchmarkUtils( - benchmark_cost, - fw_extrace, - bw_extrace, - fw_traces, - bw_traces, - primal_trace, - selected_executors=list(compile_data.executors_list) - ) - - log(f'Best executor for {ex_type} iteration: {result.executor}') - - c, m , _ = benchmark_trace(result.fw_trace, iters=benchmark_iters, apply_del_last_used=False, level=LogLevel.DEBUG) - log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best fw_extrace (time = {c}, mem = {m}):\n{result.fw_trace}", level=LogLevel.DEBUG) - c, m , _ = benchmark_trace(result.bw_trace, iters=benchmark_iters, apply_del_last_used=False, fw_trace=result.fw_trace, level=LogLevel.DEBUG) - log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type} best bw_extrace (time = {c}, mem = {m}):\n{result.bw_trace}", level=LogLevel.DEBUG) - - # Update the executor list with the winner executor for the current ex_type - # TODO: unify this by using global_optimal_result as it should contain already the optimal placement - cached_executor_list = [ex for ex in cached_executor_list if ex not in to_benchmark] - if result.executor is None: - log( - f"================================================================================ Before Autotune Tuning: Could not find best executor for {ex_type}. Assigning torchex by default", level=LogLevel.INFO) - result.executor = torchex - cached_executor_list.insert(0, result.executor) - placed.add(result.executor) - # Restore all placed but not included in the executor list - for ex_placed in placed: - if ex_placed not in cached_executor_list: - cached_executor_list.append(ex_placed) - log( - f"================================================================================ Before Autotune Tuning: Best executor for {ex_type}: {result.executor.name}", level=LogLevel.INFO) - log( - f"================================================================================ Before Autotune Tuning: Benchmark {ex_type}, updated executor list: {cached_executor_list}", level=LogLevel.INFO) - - # Update the compile stats on the last iter - if i == len(executors_candidates)-1: - # Check that we have solution, we don't have it if not requested from the executor list - if not is_tuned: - raise AssertionError( - f"No executors have been placed inside the trace. Will autotune the computation_trc ignoring the following executors:\n{executors_candidates}" - ) - - # Restore - executors_list_to_restore = cached_executor_list if result.cost < global_optimal_result.cost else global_optimal_result.selected_executors - result_to_assign = result if result.cost < global_optimal_result.cost else global_optimal_result - - compile_data.executors_list = list(executors_list_to_restore) - log( - f"================================================================================ Before Autotune Tuning: autotuned split_forward_backward from {executors_candidates}", level=LogLevel.INFO) - if compile_stats is not None: - compile_stats.last_traces.append(result_to_assign.primal_trace) - compile_stats.last_traces += result_to_assign.fw_traces - compile_stats.last_backward_traces += result_to_assign.bw_traces - - return result_to_assign.fw_trace, result_to_assign.bw_trace - except Exception as exception: - import traceback - ex_str = traceback.format_exc() - log(f'Exception occured when tuning {executors_candidates}: {exception}\n{ex_str}') - # Restore before calling split - compile_data.executors_list = cached_executor_list_copy - - log( - f"================================================================================ Before Autotune Tuning: exception occured, executors:\n{executors_candidates} will not be autotuned (priority list policy will be used)", - level=LogLevel.INFO, - ) - primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() - if compile_stats is not None: - compile_stats.last_traces.append(primal_trace) - compile_stats.last_traces += fw_traces - compile_stats.last_backward_traces += bw_traces - - return fw_extrace, bw_extrace - else: - return split() + return fw_extrace, bw_extrace From c0f73c5496a27bf468170c6bb553f0dbf2668216 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 20 Aug 2024 22:23:24 +0300 Subject: [PATCH 058/171] Beam search for fw bw split operators --- thunder/executors/torch_autograd.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 2589afb057..85f112b1e8 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -3,7 +3,6 @@ import torch -from thunder.backend_optimizer.utils import operation_in_trace import thunder.core.utils as utils from thunder.core.prims import PrimIDs from thunder.core.proxies import TensorProxy, variableify @@ -11,7 +10,6 @@ from thunder.core.symbol import BoundSymbol from thunder.core.trace import TraceCtx, from_trace, set_tracectx, reset_tracectx from thunder.core.transform_common import replace_redundant_inputs -from thunder.extend import OperatorExecutor, Executor from thunder.core.vjp_utils import get_saved_for_backward_tensors if TYPE_CHECKING: From 9b1d0cb0c985800abbfdb6d24e1966f0bcf78641 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 21 Aug 2024 15:55:41 +0300 Subject: [PATCH 059/171] Fixed issues about remat and solved sdpa backward pass benchmark issues about wrong input args --- examples/dev/sdpa.py | 23 +- thunder/__init__.py | 5 +- thunder/backend_optimizer/optimizer.py | 22 +- thunder/backend_optimizer/utils.py | 337 +++++++++++++++++-------- thunder/core/vjp_utils.py | 21 +- thunder/executors/sdpaex.py | 7 +- thunder/executors/torch_autograd.py | 7 +- 7 files changed, 283 insertions(+), 139 deletions(-) diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index 505855fece..63a71f9aef 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -2,33 +2,34 @@ import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark -torch.set_default_dtype(torch.bfloat16) +dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 +torch.set_default_dtype(dtype) +print(f'Script data type: {dtype}') class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, query, key, value, q_l, k_l, v_l): + def forward(self, query, key, value): a = torch.nn.functional.scaled_dot_product_attention(query, key, value) - b = torch.nn.functional.scaled_dot_product_attention(q_l, k_l, v_l) - return a, b + # Make different inputs as happens in a real model + b = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) + c = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) + d = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) + return a * b + c - d with torch.device('cuda'): model = Model() jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'sdpa', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'cudnn', 'sdpa'], use_cudagraphs=False) q = torch.rand(32, 8, 128, 64*1, requires_grad=True) k = torch.rand(32, 8, 128, 64*1, requires_grad=True) v = torch.rand(32, 8, 128, 64*1, requires_grad=True) - q_l = torch.rand(32, 8, 128, 64*1, requires_grad=True) - k_l = torch.rand(32, 8, 128, 64*1, requires_grad=True) - v_l = torch.rand(32, 8, 128, 64*1, requires_grad=True) - - jmodel_def(q, k, v, q_l, k_l, v_l) - jmodel_auto(q, k, v, q_l, k_l, v_l) + jmodel_def(q, k, v) + jmodel_auto(q, k, v) iters = 100 fw_traces = [ diff --git a/thunder/__init__.py b/thunder/__init__.py index 4e6d6dcc09..f932b0cc40 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -327,7 +327,10 @@ def jit( if required_autotune not in ['runtime', 'memory']: raise AssertionError(f'Not supported optimization: {required_autotune}') - compile_options |= {"autotune_type": OptimizerType.RUNTIME if required_autotune == 'runtime' else OptimizerType.MEMORY} + compile_options |= { + "autotune_type": OptimizerType.RUNTIME if required_autotune == 'runtime' else OptimizerType.MEMORY, + "executors_placed_by_fw_bw_split": set() + } # Default the executors list to all_executors if no options are given # Otherwise the user restricted choice will be used diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 648ffd1bad..f3e74b8dd5 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -11,7 +11,7 @@ from thunder.backend_optimizer.utils import benchmark_trace # Defining a wrapper fn as the imports will crash in the global scope -def get_fw_bw_split_backends_options(bsym: BoundSymbol) -> list: +def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None) -> list | dict: from thunder.executors.sdpaex import sdpa_ex from thunder.executors.cudnnex import cudnn_ex from thunder.executors.fa3ex import fa3_ex @@ -23,7 +23,7 @@ def get_fw_bw_split_backends_options(bsym: BoundSymbol) -> list: 'scaled_dot_product_attention': [sdpa_ex, cudnn_ex, fa3_ex], } - return options.get(bsym.sym.name, []) + return options.get(bsym.sym.name, []) if bsym else options class BenchmarkResult: @@ -602,8 +602,8 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # Retrieve partial trace and benchmark, apply remat if possible trc, _, _ = get_placed_trace(dict_time_strat, increasing_symbols) # Apply fw bw remat - if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: - _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) + # if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: + # _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) # Now, benchmark t, m, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) # Update results @@ -649,8 +649,8 @@ def measure_and_update_result(): nonlocal best_placement_mem nonlocal best_keys_mem trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) - if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: - _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) + # if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: + # _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) cost, mem, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) log(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) if cost < best_res_time.runtime or (cost == best_res_time.runtime and mem < best_res_time.memory): @@ -802,9 +802,9 @@ def measure_and_update_result(): empty_str=self.empty_executor_hashable_placeholder, ) # print(f"Assigned trace:\n{trc}") - if self.trace_type == TraceType.BW: - # pass - _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) + # if self.trace_type == TraceType.BW: + # # pass + # _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) container.append({ex.name: trc}) # Save executors in order to generate real fw and bw trace with correct output with the placer @@ -874,7 +874,7 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): match self.trace_type: case TraceType.FW: log( - f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.DEBUG + f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO ) # TODO (matteochen): support bw trace optimization even though with no fw traces cached (computational trace?) case TraceType.BW: @@ -882,7 +882,7 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): raise AssertionError("Can not optimize backward traces before forward traces") log( f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", - level=LogLevel.DEBUG, + level=LogLevel.INFO, ) def optimize(self): diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 91fcdbe82d..6dd5412cc9 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1,5 +1,6 @@ from collections.abc import Callable, Hashable, Sequence from typing import Any +from thunder.core.compile_data import get_compile_data from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs from thunder.core.proxies import AnyProxy, CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify @@ -11,6 +12,7 @@ from itertools import chain import torch +# Maybe we can use id(s) def sequence_hash(s: Sequence) -> str: def rec(s) -> str: name = "[" @@ -258,8 +260,30 @@ def operation_in_trace(*, trace: TraceCtx, op: str) -> bool: def is_te_used(trace: TraceCtx) -> bool: from thunder.executors.transformer_engineex import linear_bound_symbol_name_prefix from thunder.executors.transformer_engineex import te_functional_linear_backward_name + + for bsym in trace.bound_symbols: + if ( + bsym.sym.name.startswith(linear_bound_symbol_name_prefix) + or bsym.sym.name == te_functional_linear_backward_name + ): + return True + return False + +def is_te_ex_bw_used(trace: TraceCtx) -> bool: + from thunder.executors.transformer_engineex import te_functional_linear_backward_name for bsym in trace.bound_symbols: - if bsym.sym.name.startswith(linear_bound_symbol_name_prefix) or bsym.sym.name.startswith(te_functional_linear_backward_name): + if bsym.sym.name == te_functional_linear_backward_name: + return True + return False + +def is_sdpa_ex_bw_used(trace: TraceCtx) -> bool: + from thunder.executors.sdpaex import ( + sdpaex_scaled_dot_product_efficient_attention_backward_name as n1, + sdpafx_scaled_dot_product_efficient_attention_backward_name as n2, + ) + + for bsym in trace.bound_symbols: + if bsym.sym.name == n1 or bsym.sym.name == n2: return True return False @@ -277,9 +301,6 @@ def benchmark_trace( from thunder.executors.passes import del_last_used import inspect - # In order to benchmark traces with TE enabled, the backward pass needs the context object returned from the forward trace - cached_fw_te_ctx_out = None - def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: warm_up_iters = 50 @@ -302,9 +323,8 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f return float("inf"), float("inf"), None except Exception as e: import inspect - trc = inspect.getsource(fn) - print(f"#NVSIGHT FN EXECUTION FAILED:\n{trc}") + print(f"#Trace execution failed for nvsight (error: {e}):\n{trc}") raise e def clone_args(args): @@ -321,6 +341,7 @@ def clone_args(args): def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: try: + current_iter = 0 warm_up_iters = 50 out = None @@ -348,6 +369,7 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] torch.cuda.synchronize() for i in range(iters): + current_iter = i cloned_args = clone_args(args) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() @@ -366,73 +388,11 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl tot_time = sum(times) / iters return tot_time, max_allocated_bytes, out except Exception as e: - print(f"#FN EXECUTION FAILED:\n{repr}") + print(f"#Trace execution failed at iter {current_iter} (error: {e})\n{repr}") raise e - # TODO (matteochen): use more appropriate mock int and float - def transform_input_sequence(sequence: Sequence, level=0) -> tuple | list: - res = [] - for e in sequence: - if type(e) is tuple: - res.append(transform_input_sequence(e, level + 1)) - else: - if isinstance(e, TensorProxy): - res.append(transform_tensor(e)) - elif isinstance(e, IntegerProxy): - if e.python_type is bool: - res.append(False if e.value is None else e.value) - else: - res.append(0 if e.value is None else e.value) - elif isinstance(e, FloatProxy): - res.append(0.0 if e.value is None else e.value) - # Transformer engine context object - elif hasattr(e, 'name') and isinstance(e, AnyProxy) and e.name.startswith('ctx_te'): - res.append(cached_fw_te_ctx_out) - elif e is None: - res.append(None) - else: - raise AssertionError(f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}') - return tuple(res) if level > 0 else res - - def transform_tensor(arg: TensorProxy) -> torch.Tensor: - from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype, is_boolean_dtype - - # TODO (matteochen): Missing parallel and fsdp handling... - # TODO (matteochen): Missing support for meta types ... - dtype = arg.dtype - shape = arg.shape - device = arg.device - requires_grad = arg.requires_grad - torch_dtype = to_torch_dtype(dtype) - if torch_dtype is None: - raise AssertionError(f'Unrecognized thunder dtype: {dtype}') - if is_float_dtype(dtype): - # Use TE Float8 if TE is enabled, it has float32 ad torch dtype - if te_used: - tensor: torch.Tensor = torch.randn( - shape, dtype=torch_dtype if dtype.bytes > 1 else torch.float32, device=device.device_str(), requires_grad=requires_grad - ) - if dtype.bytes == 1: - import transformer_engine.pytorch as te - tensor = te.float8_tensor.Float8Tensor.to_float8(tensor) - # Support standard float tensors - else: - tensor: torch.Tensor = torch.randn( - shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad - ) - elif is_signedinteger_dtype(dtype): - tensor: torch.Tensor = torch.randint( - 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad - ) - elif is_boolean_dtype(dtype): - # TODO (matteochen): maybe random? - tensor: torch.Tensor = torch.zeros( - *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad - ) - else: - raise AssertionError(f"dtype {dtype} not supported yet") - - return tensor + def build_static_args(sequence: Sequence, **kwargs): + return transform_proxy_to_torch(sequence, level=0, **kwargs) # We have to fix the saved_for_backward tuple as TE output TensorProxy don't have a correct one def fix_te_backward_inputs(inputs: list): @@ -450,6 +410,92 @@ def fix_te_backward_inputs(inputs: list): fixed_inputs_first_index = tuple([tuple(saved_for_bw), inputs[0][1]]) return fixed_inputs_first_index + def TE_backward_trace_preprocessing(): + nonlocal input_args + # Due to the fw and bw split benchmarking we have to check the bw nature by looking for the bsym + is_bw = is_te_ex_bw_used(trace) + # If transformer_engine executor is used and it is the bw function we have to recover the forward context from the forward trace + if is_bw and 'fw_trace' not in kwargs: + raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with TE executor') + elif is_bw: + # print('TE Benchmarking fw trace for bw') + fw_trace = kwargs.get('fw_trace', None) + if not isinstance(fw_trace, TraceCtx): + raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') + # Run the fw trace and get the outputs + fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] + # Retrive the context from the fw pass output + # Currently it will contain an empty transformer_engineex.Context but might be useful for the future + cached_fw_te_ctx_out = fw_output[1][1][0] + + # After got the Context object from the fw pass we con build the input args list for the bw pass + input_args = build_static_args(trace.args, cached_fw_te_ctx_out=cached_fw_te_ctx_out, te_used=True) + + # Fix TE arguments for benchmark + first_tuple = fix_te_backward_inputs(input_args) + input_args.pop(0) + input_args.insert(0, first_tuple) + + def SDPA_backward_trace_preprocessing(): + nonlocal input_args + if 'fw_trace' not in kwargs: + raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with sdpa executor') + fw_trace = kwargs.get('fw_trace', None) + if not isinstance(fw_trace, TraceCtx): + raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') + # Run the fw trace and get the outputs + fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] + + # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) + sig = fw_trace.signature_with_no_ctx() + is_fw_final_trace = sig.startswith('def augmented') + print(f'Is backward sdpa final trace? {is_fw_final_trace}') + + # Filter the output tuple + saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] + + # Retrieve the compile time arguments + input_args = build_static_args(trace.args) + + # Now, we expected that if the fw trace is a final trace also the bw trace is a final one. And vice versa + if is_fw_final_trace: + # Swap saved_for_backward_traces + saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) + input_args.pop(0) + input_args.insert(0, saved_for_bw) + else: + # SDPA single region backward trace receives as input the saved_for_bw tensors plus some others. + # They are indexed like [saved_for_bw, others...] + """ + Example: + @torch.no_grad() + @no_autocast + def _cudnn_sdpa_bwd_wrapper(query, key, value, attn_mask, dropout_p=0.0, is_causal=False, *, scale=None): + # query: "cuda:0 bf16[32, 8, 128, 64]" + # key: "cuda:0 bf16[32, 8, 128, 64]" + # value: "cuda:0 bf16[32, 8, 128, 64]" + # dropout_p: "float 0.0" + # is_causal: "bool False" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, dropout_p, is_causal, scale=None) + return (t0, [query, key, value, dropout_p, is_causal, t0, t1, t2, t3]) + + @torch.no_grad() + @no_autocast + def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causal, t0, t1, t2, t3, t4): + (t5, t6, t7) = cudnn_sdpa_bwd(t4, query, key, value, None, dropout_p, is_causal, t0, t1, t2, t3, scale=None, cat_grad_qkv=False) + return {'query': t5, 'key': t6, 'value': t7, 'attn_mask': None, 'dropout_p': None, 'is_causal': None, 'scale': None} + + See how the backward trace need t4 as argument recoveered from the static args + """ + updated_input_args = [t for t in saved_for_bw_C0] + updated_input_args.extend(input_args[len(updated_input_args):]) + # print('Updated input_args') + # print_args(updated_input_args) + input_args = updated_input_args + + # Input args for the trace to benchmark + input_args = [] + # Check for correctness if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: raise AssertionError("Missing return statement") @@ -457,33 +503,23 @@ def fix_te_backward_inputs(inputs: list): if apply_del_last_used: trace = del_last_used(trace) - # Enable TE fp8 autocast if needed + # Handle TE traces te_used = is_te_used(trace) + sdpa_ex_bw_used = is_sdpa_ex_bw_used(trace) + if te_used and sdpa_ex_bw_used: + raise AssertionError("Not handled") + if te_used: cached_te_fp8_autocast_value = trace._include_te_fp8_autocast trace._include_te_fp8_autocast = True + TE_backward_trace_preprocessing() + # Fix sdpaex arguments for backward benchmarks + elif sdpa_ex_bw_used: + SDPA_backward_trace_preprocessing() + # "Default" trace, parse the input args...(input args parsing will be performed by the TE trace handling) + else: + input_args: list = build_static_args(trace.args) - # If transformer_engine executor is used and it is the bw function we have to recover the forward context from the forward trace - trace_signature = trace.signature_with_no_ctx() - if te_used and trace_signature.startswith('def backward') and 'fw_trace' not in kwargs: - raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with TE executor') - elif te_used and trace_signature.startswith('def backward'): - # print('TE Benchmarking fw trace for bw') - fw_trace = kwargs.get('fw_trace', None) - if not isinstance(fw_trace, TraceCtx): - raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') - # Run the fw trace and get the outputs - fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] - # Retrive the context from the fw pass output - # Currently it will contain an empty transformer_engineex.Context but might be useful for the future - cached_fw_te_ctx_out = fw_output[1][1][0] - - input_args: list = transform_input_sequence(trace.args) - - if te_used and trace_signature.startswith('def backward'): - first_tuple = fix_te_backward_inputs(input_args) - input_args.pop(0) - input_args.insert(0, first_tuple) trace_tok = set_tracectx(trace) @@ -580,15 +616,114 @@ def print_trace_args(trace: TraceCtx): # Display nest sequence arguments def print_args(args): - print('\n\n###################################### Debug args') - def _print(args): - print('Sequence start') + + def is_tensor(t): + return isinstance(t, torch.Tensor) or isinstance(t, TensorProxy) + + if not isinstance(args, Sequence): + return + print('###################################### Sequence start') + def _print(args, level): + tabs = '\t' * level + print(f'Level {level} start') for arg in args: if isinstance(arg, Sequence): - _print(arg) + _print(arg, level+1) else: - tensor_shape = arg.shape if isinstance(arg, torch.Tensor) else None - print(f'{type(arg)} {tensor_shape if tensor_shape else ""}') - print('Sequence end') - _print(args) + tensor_shape = arg.shape if is_tensor(arg) else None + dtype = arg.dtype if is_tensor(arg) else None + name = arg.name if isinstance(arg, TensorProxy) else "" + print(f'{tabs}{name + ": " if name else ""}{type(arg)}{arg if isinstance(arg, dict) else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}') + print(f'Level {level} end') + _print(args, 0) print('###################################### Debug args\n') + +def update_compile_options_executor_list_after_fw_bw_split() -> None: + from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options + cd = get_compile_data() + assert cd + + # Get all the possible options that the vjp_optimization pass will use + options: dict = get_fw_bw_split_backends_options() + executors_list = list(cd.executors_list) + + # Remove all the initial options + for _, v in options.items(): + for ex in v: + if ex in executors_list: + executors_list.remove(ex) + + # Putting at the front event though order does not matter + for ex in cd.compile_options['executors_placed_by_fw_bw_split']: + executors_list.insert(0, ex) + + # Assign new compilation executors options + cd.executors_list = executors_list + +def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: + from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype, is_boolean_dtype + + # TODO (matteochen): Missing parallel and fsdp handling... + # TODO (matteochen): Missing support for meta types ... + dtype = arg.dtype + shape = arg.shape + device = arg.device + requires_grad = arg.requires_grad + torch_dtype = to_torch_dtype(dtype) + if torch_dtype is None: + raise AssertionError(f'Unrecognized thunder dtype: {dtype}') + if is_float_dtype(dtype): + # Use TE Float8 if TE is enabled, it has float32 ad torch dtype + te_used = kwargs.get('te_used', False) + if te_used: + tensor: torch.Tensor = torch.randn( + shape, dtype=torch_dtype if dtype.bytes > 1 else torch.float32, device=device.device_str(), requires_grad=requires_grad + ) + if dtype.bytes == 1: + import transformer_engine.pytorch as te + tensor = te.float8_tensor.Float8Tensor.to_float8(tensor) + # Support standard float tensors + else: + tensor: torch.Tensor = torch.randn( + shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif is_signedinteger_dtype(dtype): + tensor: torch.Tensor = torch.randint( + 0, 8, shape, dtype=torch_dtype, device=device.device_str(), requires_grad=requires_grad + ) + elif is_boolean_dtype(dtype): + # TODO (matteochen): maybe random? + tensor: torch.Tensor = torch.zeros( + *shape, dtype=torch.bool, device=device.device_str(), requires_grad=requires_grad + ) + else: + raise AssertionError(f"dtype {dtype} not supported yet") + + return tensor + +# TODO (matteochen): use more appropriate mock int and float +def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | list: + res = [] + for e in sequence: + if type(e) is tuple: + res.append(transform_proxy_to_torch(e, level + 1)) + else: + if isinstance(e, TensorProxy): + res.append(transform_tensor(e, **kwargs)) + elif isinstance(e, IntegerProxy): + if e.python_type is bool: + res.append(False if e.value is None else e.value) + else: + res.append(0 if e.value is None else e.value) + elif isinstance(e, FloatProxy): + res.append(0.0 if e.value is None else e.value) + # Transformer engine context object + elif hasattr(e, 'name') and isinstance(e, AnyProxy) and e.name.startswith('ctx_te'): + context = kwargs.get('cached_fw_te_ctx_out', None) + assert context is not None + res.append(context) + elif e is None: + res.append(None) + else: + raise AssertionError(f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}') + return tuple(res) if level > 0 else res diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index c946c334bf..160b3809fe 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -202,8 +202,7 @@ def bw_fn(*args, **kwargs): from thunder.backend_optimizer.utils import benchmark_trace # In order define this unique trace region we need a unique id - bsym_id = id(bsym) - key = (bsym.sym, Executor(f'{bsym_id}-autotuned'), subkey := _make_cache_key(bsym.args, bsym.kwargs)) + key = (bsym.sym, Executor(f'{id(bsym)}-autotuned'), subkey := _make_cache_key(bsym.args, bsym.kwargs)) # We do check the cache here as the key in the inner fn does not know about this special id cached_result = _cache.get(key, None) if subkey is not None else None # NOTE: cache is always enabled here @@ -217,25 +216,24 @@ def bw_fn(*args, **kwargs): cached_executors_list = list(cd.executors_list) # Retrieve all the executors which are requested to be used requested_executors_list_for_bsym = [ex for ex in cached_executors_list if ex in backends] - # from thunder.executors.torchex import ex as torchex - # if torchex not in executors_list: - # executors_list.append(torchex) from thunder.benchmarks.utils import SplitFwBwBenchmarkUtils from thunder.backend_optimizer.optimizer import OptimizerType best = SplitFwBwBenchmarkUtils() # Restrict the search space backends = list(requested_executors_list_for_bsym) - print(f'Search space for {bsym.sym.name}: {backends}') + + from thunder.backend_optimizer.optimizer import log, LogLevel + log(f'Search space for {bsym.sym.name}: {backends}', level=LogLevel.DEBUG) for b in backends: - print('Benchmarking backend', b.name) + log(f'Benchmarking executor {b.name} for {bsym.sym.name}', level=LogLevel.DEBUG) # Let downstream fn to pick up this requested_executors_list_for_bsym.remove(b) requested_executors_list_for_bsym.insert(0, b) cd.executors_list = requested_executors_list_for_bsym fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(True) fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=20, apply_del_last_used=False) - bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=20, apply_del_last_used=False) + bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=20, apply_del_last_used=False, fw_trace=fw_trace) cost = fw_time + bw_time if cd.compile_options['autotune_type'] == OptimizerType.RUNTIME else fw_mem + bw_mem print('cost', cost) if cost < best.cost: @@ -243,14 +241,13 @@ def bw_fn(*args, **kwargs): assert best.cost != float('inf') from thunder.backend_optimizer.optimizer import log - log(f'Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}') + log(f'Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}', level=LogLevel.DEBUG) # Update the compile options - cached_executors_list = [ex for ex in cached_executors_list if ex not in backends] - cached_executors_list.insert(0, best.executor) - + cd.compile_options["executors_placed_by_fw_bw_split"].add(best.executor) # Restore executor list for downstream optimizations cd.executors_list = cached_executors_list + # The executors used in this pass will be updated after the termination of the forward_and_backward_from_trace call _cache[key] = best.fw_fn, best.bw_fn return best.fw_fn, best.bw_fn diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index ddd82ca915..76f0e733d8 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -389,7 +389,7 @@ def _scaled_dot_product_efficient_attention_backward_impl( scale=scale, ) - +sdpaex_scaled_dot_product_efficient_attention_backward_name = "sdpaex_scaled_dot_product_efficient_attention_backward" sdpea_bwd = sdpa_ex.register_operator( "sdpaex_scaled_dot_product_efficient_attention_backward", meta=_scaled_dot_product_efficient_attention_backward_meta, @@ -470,8 +470,9 @@ def _scaled_dot_product_flash_attention_backward_impl( return (_sdpa_slice_head_dimension(g, value.shape[-1]) for g in grads) +sdpafx_scaled_dot_product_efficient_attention_backward_name = "sdpafx_scaled_dot_product_efficient_attention_backward" sdpfa_bwd = sdpa_ex.register_operator( - "sdpafx_scaled_dot_product_efficient_attention_backward", + sdpafx_scaled_dot_product_efficient_attention_backward_name, meta=_scaled_dot_product_flash_attention_backward_meta, fn=_scaled_dot_product_flash_attention_backward_impl, ) @@ -504,11 +505,13 @@ def _scaled_dot_product_attention_fused( tensor_args = (query, key, value) scalar_args = (dropout_p, is_causal) if backend == SpdaBackend.FLASH_ATTENTION: + print('FLASH ATT') # Use flash attention kernel (primal, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _) = sdpfa_gradfwd( *tensor_args, *scalar_args, scale=scale ) elif backend == SpdaBackend.MEMORY_EFFICIENT: + print('MEM EFF') # Use memory efficient kernel, which supports fp32 and attention mask arguments (primal, logsumexp, philox_seed, philox_offset) = sdpea_gradfwd( *tensor_args, attn_mask, *scalar_args, scale=scale diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 85f112b1e8..477087d9e9 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -137,13 +137,14 @@ def update_bw_from_forward_optimization(*, fw: TraceCtx, bw:TraceCtx) -> TraceCt return bw def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args): + from thunder.backend_optimizer.optimizer import TraceType, BackendOptimizer + from thunder.backend_optimizer.utils import update_compile_options_executor_list_after_fw_bw_split from thunder.core.rematerialization import rematerialize_all_gather, rematerialize_forward_and_backward from thunder.core.transforms import forward_and_backward_from_trace from thunder.distributed.transforms import FSDPCommBucketing from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops from thunder.executors.passes import del_last_used, transform_for_execution, autotune_transform_for_execution from thunder.visualizer.visualizer_helper import Visualizer - from thunder.backend_optimizer.optimizer import log, LogLevel, TraceType, BackendOptimizer, OptimizerType, benchmark_trace utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used @@ -170,6 +171,10 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # the forward trace and inputs of the backward trace. fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True) + # Update the autotuned executors list + if autotune_type: + update_compile_options_executor_list_after_fw_bw_split() + fw_traces = [fw_trace] bw_traces = [bw_trace] From 2eeb6fcaab37d65dfd8dc0663006ae5bbc577053 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 21 Aug 2024 15:56:32 +0300 Subject: [PATCH 060/171] Updated test --- examples/dev/sdpa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index 63a71f9aef..aafd2f0256 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -4,7 +4,7 @@ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 torch.set_default_dtype(dtype) -print(f'Script data type: {dtype}') +print(f'Script data type: {dtype}\n') class Model(torch.nn.Module): def __init__(self) -> None: From 090b5958fb87e329398a32cb677b1f45e459a29e Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 21 Aug 2024 17:13:21 +0300 Subject: [PATCH 061/171] Removed print --- thunder/core/vjp_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 160b3809fe..edae535811 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -235,7 +235,6 @@ def bw_fn(*args, **kwargs): fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=20, apply_del_last_used=False) bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=20, apply_del_last_used=False, fw_trace=fw_trace) cost = fw_time + bw_time if cd.compile_options['autotune_type'] == OptimizerType.RUNTIME else fw_mem + bw_mem - print('cost', cost) if cost < best.cost: best = SplitFwBwBenchmarkUtils(cost = cost, fw_fn = fw_fn, bw_fn = bw_fn, executor = b) From d030b1e18fa053ff82c5ec11ad0b14bd69a7a026 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 21 Aug 2024 18:25:33 +0300 Subject: [PATCH 062/171] Updated tests --- examples/dev/sdpa.py | 52 ++++++++++- examples/dev/te.py | 7 +- thunder/backend_optimizer/utils.py | 137 ++++++++++++++++++++--------- 3 files changed, 148 insertions(+), 48 deletions(-) diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index aafd2f0256..441a65474f 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -1,6 +1,6 @@ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_total_benchmark dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 torch.set_default_dtype(dtype) @@ -14,9 +14,47 @@ def forward(self, query, key, value): a = torch.nn.functional.scaled_dot_product_attention(query, key, value) # Make different inputs as happens in a real model b = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) - c = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) - d = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) - return a * b + c - d + # c = torch.nn.functional.scaled_dot_product_attention(query*query, key*key, value*value) + # d = torch.nn.functional.scaled_dot_product_attention(query-query, key-key, value-value) + return a + b + + +def bench(m, label, iters): + q = torch.rand(32, 8, 128, 64*1, requires_grad=True) + k = torch.rand(32, 8, 128, 64*1, requires_grad=True) + v = torch.rand(32, 8, 128, 64*1, requires_grad=True) + + # warm up + for _ in range(50): + y = m(q, k, v) + # y.sum().backward() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + max_allocated_bytes = 0 + torch.cuda.synchronize() + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + + start_events[i].record(stream) + y = m(q, k, v) + loss = y.sum() + # loss.backward() + end_events[i].record(stream) + + max_allocated_bytes = max( + max_allocated_bytes, torch.cuda.max_memory_allocated( + torch.cuda.current_device()) + ) + + torch.cuda.synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f'{label} tot time: {tot_time} ms') + print(f'{label} max allocated memory: {max_allocated_bytes / (2**30)} gb') with torch.device('cuda'): model = Model() @@ -42,8 +80,10 @@ def forward(self, query, key, value): ] fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] + print('Thunder benchmark:') thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) + print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') print('###############################################################################') @@ -53,3 +93,7 @@ def forward(self, query, key, value): print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + + print('\nTorch benchmark:') + bench(jmodel_def, 'def', iters) + bench(jmodel_auto, 'auto', iters) diff --git a/examples/dev/te.py b/examples/dev/te.py index efd59c7392..1c1d541ab1 100644 --- a/examples/dev/te.py +++ b/examples/dev/te.py @@ -7,11 +7,12 @@ def __init__(self, in_features, out_features) -> None: super().__init__() self.linear = torch.nn.Sequential( torch.nn.Linear(in_features, out_features), + torch.nn.Linear(out_features, in_features), + torch.nn.Linear(in_features, out_features), ) def forward(self, x: torch.Tensor): - a = x + x - return self.linear(a) + return self.linear(x) with torch.device('cuda'): m = 1 @@ -24,7 +25,7 @@ def forward(self, x: torch.Tensor): jmodel_auto = thunder.jit( model, autotune_type="runtime", - executors=["nvfuser", "transformer_engine", "cudnn", "torch"], + executors=["nvfuser", "transformer_engine", ], use_cudagraphs=False, ) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 6dd5412cc9..fe82c13c1c 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -394,27 +394,30 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl def build_static_args(sequence: Sequence, **kwargs): return transform_proxy_to_torch(sequence, level=0, **kwargs) - # We have to fix the saved_for_backward tuple as TE output TensorProxy don't have a correct one - def fix_te_backward_inputs(inputs: list): - saved_for_bw = [] - for i, e in enumerate(inputs[0][0]): - # This tensor should be an uint8 https://github.com/NVIDIA/TransformerEngine/blob/4edcff5777be08b6f89658572c433aa8f36acf0d/transformer_engine/pytorch/module/linear.py#L366 - if i == 1: - inputmat_t = e - if inputmat_t.dtype != torch.uint8: - inputmat_t = torch.randint(0, 8, (e.shape), dtype=torch.uint8, device=e.device) - saved_for_bw.append(inputmat_t) - else: - saved_for_bw.append(e) - - fixed_inputs_first_index = tuple([tuple(saved_for_bw), inputs[0][1]]) - return fixed_inputs_first_index - def TE_backward_trace_preprocessing(): nonlocal input_args - # Due to the fw and bw split benchmarking we have to check the bw nature by looking for the bsym + # Due to the fw and bw split benchmarking we have to check the bw nature by looking at the bsym is_bw = is_te_ex_bw_used(trace) - # If transformer_engine executor is used and it is the bw function we have to recover the forward context from the forward trace + """ + If transformer_engine executor is used and it is the bw function we have to recover the forward context and saved_for_backward tensors from the forward trace. + Why we can't generate saved_for_bw tensors from static compiation data? Yes we could but the static compilation data are broken. See the function below: + + def fix_te_backward_inputs(inputs: list): + # inputs = old saved_for_backward + saved_for_bw = [] + for i, e in enumerate(inputs[0][0]): + # This tensor should be an uint8 https://github.com/NVIDIA/TransformerEngine/blob/4edcff5777be08b6f89658572c433aa8f36acf0d/transformer_engine/pytorch/module/linear.py#L366 + if i == 1: + inputmat_t = e + if inputmat_t.dtype != torch.uint8: + inputmat_t = torch.randint(0, 8, (e.shape), dtype=torch.uint8, device=e.device) + saved_for_bw.append(inputmat_t) + else: + saved_for_bw.append(e) + + fixed_inputs_first_index = tuple([tuple(saved_for_bw), inputs[0][1]]) + return fixed_inputs_first_index + """ if is_bw and 'fw_trace' not in kwargs: raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with TE executor') elif is_bw: @@ -422,19 +425,60 @@ def TE_backward_trace_preprocessing(): fw_trace = kwargs.get('fw_trace', None) if not isinstance(fw_trace, TraceCtx): raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') - # Run the fw trace and get the outputs + # Run the fw trace and get the runtime outputs fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] - # Retrive the context from the fw pass output - # Currently it will contain an empty transformer_engineex.Context but might be useful for the future - cached_fw_te_ctx_out = fw_output[1][1][0] - # After got the Context object from the fw pass we con build the input args list for the bw pass - input_args = build_static_args(trace.args, cached_fw_te_ctx_out=cached_fw_te_ctx_out, te_used=True) + # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) + sig = fw_trace.signature_with_no_ctx() + is_fw_final_trace = sig.startswith('def augmented') - # Fix TE arguments for benchmark - first_tuple = fix_te_backward_inputs(input_args) - input_args.pop(0) - input_args.insert(0, first_tuple) + # Filter the output tuple + saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] + + # After got the Context object from the fw pass we can build the input args list for the bw pass + # The underlying API will generate TE.Floa8 tensors also hence it must know if we are dealing with torch.float8 or TE.Float8 + input_args = build_static_args(trace.args, te_used=True) + + # Now, we expect that if the fw trace is a final trace also the bw trace is a final one. And vice versa + if is_fw_final_trace: + # Swap saved_for_backward_traces + saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) + # Subsitute the static inputs for saved_for_backward with the runtime ones + input_args.pop(0) + input_args.insert(0, saved_for_bw) + else: + # Transformer Engine single trace region backward trace receives as input the saved_for_bw tensors plus some others. + # They are indexed like [saved_for_bw, others...] + # NOTE: This may change in the future + """ + Example: + @transformer_engine.fp8_autocast(fp8_recipe=te_fp8_recipe) + @torch.no_grad() + @no_autocast + def _linear_grad(a, w, b): + # a: "cuda:0 f32[768, 4096]" + # w: "cuda:0 f32[4096, 4096]" + # b: "cuda:0 f32[4096]" + (t5, (t0, t1, t2, t3, _, t4), ctx_te_2) = te_linear_2(a, w, b) + return (t5, [ctx_te_2, t0, t1, t2, t3, t4]) + Trace + import torch + from thunder.executors.torchex import no_autocast + + @torch.no_grad() + @no_autocast + def linear_backward(ctx_te_2, t0, t1, t2, t3, t4, t6): + (t7, t8, t9) = te_functional_linear_backward((768, 4096), (4096, 4096), (4096,), ctx_te_2, (t0, t1, t2, t3, None, t4), t6) + return {'a': t7, 'w': t8, 'bias': t9} + + See how the backward trace need t6 as argument recoveered from the static args + """ + updated_input_args = [t for t in saved_for_bw_C0] + updated_input_args.extend(input_args[len(updated_input_args):]) + input_args = updated_input_args + # Forward pass + else: + input_args = build_static_args(trace.args) def SDPA_backward_trace_preprocessing(): nonlocal input_args @@ -449,7 +493,6 @@ def SDPA_backward_trace_preprocessing(): # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) sig = fw_trace.signature_with_no_ctx() is_fw_final_trace = sig.startswith('def augmented') - print(f'Is backward sdpa final trace? {is_fw_final_trace}') # Filter the output tuple saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] @@ -461,11 +504,13 @@ def SDPA_backward_trace_preprocessing(): if is_fw_final_trace: # Swap saved_for_backward_traces saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) + # Subsitute the static inputs for saved_for_backward with the runtime ones input_args.pop(0) input_args.insert(0, saved_for_bw) else: - # SDPA single region backward trace receives as input the saved_for_bw tensors plus some others. + # SDPA single trace region backward trace receives as input the saved_for_bw tensors plus some others. # They are indexed like [saved_for_bw, others...] + # NOTE: This may change in the future """ Example: @torch.no_grad() @@ -489,8 +534,6 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa """ updated_input_args = [t for t in saved_for_bw_C0] updated_input_args.extend(input_args[len(updated_input_args):]) - # print('Updated input_args') - # print_args(updated_input_args) input_args = updated_input_args # Input args for the trace to benchmark @@ -520,7 +563,6 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa else: input_args: list = build_static_args(trace.args) - trace_tok = set_tracectx(trace) # Obtain the python executable string @@ -612,10 +654,10 @@ def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *ar return out def print_trace_args(trace: TraceCtx): - print_args(trace.args) + print_nested_sequence(trace.args) # Display nest sequence arguments -def print_args(args): +def print_nested_sequence(args, show_dicts=False): def is_tensor(t): return isinstance(t, torch.Tensor) or isinstance(t, TensorProxy) @@ -633,7 +675,7 @@ def _print(args, level): tensor_shape = arg.shape if is_tensor(arg) else None dtype = arg.dtype if is_tensor(arg) else None name = arg.name if isinstance(arg, TensorProxy) else "" - print(f'{tabs}{name + ": " if name else ""}{type(arg)}{arg if isinstance(arg, dict) else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}') + print(f'{tabs}{name + ": " if name else ""}{type(arg)}{arg if isinstance(arg, dict) and show_dicts else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}') print(f'Level {level} end') _print(args, 0) print('###################################### Debug args\n') @@ -674,6 +716,7 @@ def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: raise AssertionError(f'Unrecognized thunder dtype: {dtype}') if is_float_dtype(dtype): # Use TE Float8 if TE is enabled, it has float32 ad torch dtype + # NOTE: if we have a standalone torch.float8 inside the args and it is not the TE Float8 it won't be parsed correctly for now te_used = kwargs.get('te_used', False) if te_used: tensor: torch.Tensor = torch.randn( @@ -703,10 +746,11 @@ def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: # TODO (matteochen): use more appropriate mock int and float def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | list: + from thunder.executors.transformer_engineex import Context as C res = [] for e in sequence: if type(e) is tuple: - res.append(transform_proxy_to_torch(e, level + 1)) + res.append(transform_proxy_to_torch(e, level + 1, **kwargs)) else: if isinstance(e, TensorProxy): res.append(transform_tensor(e, **kwargs)) @@ -717,11 +761,22 @@ def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | l res.append(0 if e.value is None else e.value) elif isinstance(e, FloatProxy): res.append(0.0 if e.value is None else e.value) - # Transformer engine context object + # Transformer engine Context object + # + # This instruction will populate the args with a dummy context which is not correct in theory. + # For the benchmark purpose (where this fn is currently used) this error will not impact on the runtime correctness as at the end we + # will use the cached runtime contexts from the forward pass. + # We need this only to generate a context for the static inputs (which are discarded afterwards). + # + # Backward args: (saved_for_backward, cotangents) + # saved_for_backward -> replaced by the runtime tuple + # cotangents -> static inputs will be used + # If the static input generator will be capable to generate only the cotangents then branch will not be used anymore + # + # Currently an option to fill a custom maybe real context is left. elif hasattr(e, 'name') and isinstance(e, AnyProxy) and e.name.startswith('ctx_te'): - context = kwargs.get('cached_fw_te_ctx_out', None) - assert context is not None - res.append(context) + required_context = kwargs.get('cached_fw_te_ctx_out', None) + res.append(required_context if required_context is not None else C()) elif e is None: res.append(None) else: From bb53a5eaf5f2ee8cf04594b858f6c23c90d9d90b Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 21 Aug 2024 18:47:51 +0300 Subject: [PATCH 063/171] Supporting tuples --- thunder/benchmarks/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 00c7ac844c..395f614ffe 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -125,7 +125,7 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): - y = m(input) + y = m(*input if isinstance(input, tuple) else input) y.sum().backward() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -139,7 +139,7 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) start_events[i].record(stream) - y = m(input) + y = m(*input if isinstance(input, tuple) else input) loss = y.sum() loss.backward() end_events[i].record(stream) From 74ae689edd61c057200221c48172f602e7d8d1b2 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 21 Aug 2024 18:51:52 +0300 Subject: [PATCH 064/171] Updated example --- examples/dev/sdpa.py | 46 ++------------------------------------------ 1 file changed, 2 insertions(+), 44 deletions(-) diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index 441a65474f..49fd08c71a 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -14,48 +14,8 @@ def forward(self, query, key, value): a = torch.nn.functional.scaled_dot_product_attention(query, key, value) # Make different inputs as happens in a real model b = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) - # c = torch.nn.functional.scaled_dot_product_attention(query*query, key*key, value*value) - # d = torch.nn.functional.scaled_dot_product_attention(query-query, key-key, value-value) return a + b - -def bench(m, label, iters): - q = torch.rand(32, 8, 128, 64*1, requires_grad=True) - k = torch.rand(32, 8, 128, 64*1, requires_grad=True) - v = torch.rand(32, 8, 128, 64*1, requires_grad=True) - - # warm up - for _ in range(50): - y = m(q, k, v) - # y.sum().backward() - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - stream = torch.cuda.current_stream() - max_allocated_bytes = 0 - torch.cuda.synchronize() - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) - - start_events[i].record(stream) - y = m(q, k, v) - loss = y.sum() - # loss.backward() - end_events[i].record(stream) - - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) - ) - - torch.cuda.synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print(f'{label} tot time: {tot_time} ms') - print(f'{label} max allocated memory: {max_allocated_bytes / (2**30)} gb') - with torch.device('cuda'): model = Model() @@ -83,6 +43,8 @@ def bench(m, label, iters): print('Thunder benchmark:') thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) + # print('\n\nThunder benchmark:') + # torch_total_benchmark([jmodel_def, jmodel_auto], ['def', 'auto'], [(q, k, v), (q, k, v)], iters) print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') @@ -93,7 +55,3 @@ def bench(m, label, iters): print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - - print('\nTorch benchmark:') - bench(jmodel_def, 'def', iters) - bench(jmodel_auto, 'auto', iters) From 75816513473ffe572010feeb259e13c41da36bda Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 21 Aug 2024 20:33:17 +0300 Subject: [PATCH 065/171] Enabled log and modified tests --- examples/dev/nanogpt.py | 6 +++--- examples/dev/sdpa.py | 7 ++++--- thunder/backend_optimizer/optimizer.py | 2 +- thunder/core/vjp_utils.py | 6 +++--- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index da560160f8..55005b5b03 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -42,7 +42,7 @@ def run(target: str = 'runtime'): # model init gptconf = GPTConfig( block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 12, n_head = 12, n_embd = 768, # size of the model + n_layer = 4, n_head = 12, n_embd = 768, # size of the model dropout = 0, # for determinism bias = bias, ) @@ -51,7 +51,7 @@ def run(target: str = 'runtime'): jmodel_def = thunder.jit(model) # Currently sdpa does not work? - jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'sdpa', 'torch', 'python'], use_cudagraphs=False) + jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'sdpa'], use_cudagraphs=False) if compile: print("Compiling model...") @@ -157,7 +157,7 @@ def measure_nvsight(m, label): ] fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 20) + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 100) measure_nvsight(jmodel_def, 'def') measure_nvsight(jmodel_auto, 'auto') diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index 49fd08c71a..75fd67cd39 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -43,9 +43,6 @@ def forward(self, query, key, value): print('Thunder benchmark:') thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - # print('\n\nThunder benchmark:') - # torch_total_benchmark([jmodel_def, jmodel_auto], ['def', 'auto'], [(q, k, v), (q, k, v)], iters) - print('\n\n\n\n\n\n') print(f'{thunder.last_traces(jmodel_def)[-1]}') print('###############################################################################') @@ -55,3 +52,7 @@ def forward(self, query, key, value): print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + + print('\nTorch benchmark:') + bench(jmodel_def, 'def', iters) + bench(jmodel_auto, 'auto', iters) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index f3e74b8dd5..dd12cc476c 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -19,7 +19,7 @@ def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None) -> list | #Current configuration options: dict[str, list] = { # TODO: filter out TE only if requested - 'linear': [transformer_engine_ex], + # 'linear': [transformer_engine_ex], 'scaled_dot_product_attention': [sdpa_ex, cudnn_ex, fa3_ex], } diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index edae535811..caf904652b 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -224,9 +224,9 @@ def bw_fn(*args, **kwargs): backends = list(requested_executors_list_for_bsym) from thunder.backend_optimizer.optimizer import log, LogLevel - log(f'Search space for {bsym.sym.name}: {backends}', level=LogLevel.DEBUG) + log(f'Search space for {bsym.sym.name}: {backends}', level=LogLevel.INFO) for b in backends: - log(f'Benchmarking executor {b.name} for {bsym.sym.name}', level=LogLevel.DEBUG) + log(f'Benchmarking executor {b.name} for {bsym.sym.name}', level=LogLevel.INFO) # Let downstream fn to pick up this requested_executors_list_for_bsym.remove(b) requested_executors_list_for_bsym.insert(0, b) @@ -240,7 +240,7 @@ def bw_fn(*args, **kwargs): assert best.cost != float('inf') from thunder.backend_optimizer.optimizer import log - log(f'Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}', level=LogLevel.DEBUG) + log(f'Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}', level=LogLevel.INFO) # Update the compile options cd.compile_options["executors_placed_by_fw_bw_split"].add(best.executor) From fc090c8494af6c0c1729638815d7e617daafad32 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 00:04:30 +0300 Subject: [PATCH 066/171] Fixed executors list for gradfn picking --- thunder/__init__.py | 3 +++ thunder/backend_optimizer/utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/thunder/__init__.py b/thunder/__init__.py index f932b0cc40..739c64be4a 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -339,6 +339,9 @@ def jit( # Remove python and cudagraph executors = [ex for ex in executors if ex.name != 'python' and ex.name != 'cudagraphex'] + from thunder.backend_optimizer.utils import reorder_executors_list + executors = reorder_executors_list(executors) + # Resolve names of executors executors = resolve_executors(executors) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index fe82c13c1c..1ab923cdf8 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -782,3 +782,27 @@ def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | l else: raise AssertionError(f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}') return tuple(res) if level > 0 else res + +def reorder_executors_list(executors: Sequence): + from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options + + reordered = [] + options = get_fw_bw_split_backends_options() + + are_inputs_names = isinstance(executors[0], str) + + # Put these in front to be picked up by _get_gradfn_and_executor + for _, v in options.items(): + print(v) + for ex in v: + if are_inputs_names: + if ex.name in executors: + reordered.append(ex.name) + elif ex in executors: + reordered.append(ex) + + # Add others + for ex in executors: + if ex not in reordered: + reordered.append(ex) + return reordered From df177703786a36c0bf8fcba8a5ba77636372860d Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 10:59:24 +0300 Subject: [PATCH 067/171] Adding fusion ex to executors list if not present --- thunder/backend_optimizer/utils.py | 102 ++++++++++++++++++++++++++--- 1 file changed, 93 insertions(+), 9 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 1ab923cdf8..53b0e94b27 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -287,6 +287,10 @@ def is_sdpa_ex_bw_used(trace: TraceCtx) -> bool: return True return False +def is_backward_trace(trace: TraceCtx) -> bool: + sig = trace.signature_with_no_ctx() + return sig.find('backward') >= 0 + def benchmark_trace( trace: TraceCtx, iters: int = 1, @@ -391,7 +395,7 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl print(f"#Trace execution failed at iter {current_iter} (error: {e})\n{repr}") raise e - def build_static_args(sequence: Sequence, **kwargs): + def build_static_args(sequence: Sequence, **kwargs) -> list: return transform_proxy_to_torch(sequence, level=0, **kwargs) def TE_backward_trace_preprocessing(): @@ -536,8 +540,71 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa updated_input_args.extend(input_args[len(updated_input_args):]) input_args = updated_input_args + def backward_trace_args_preprocess() -> list: + if 'fw_trace' not in kwargs: + raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with sdpa executor') + fw_trace = kwargs.get('fw_trace', None) + if not isinstance(fw_trace, TraceCtx): + raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') + # Run the fw trace and get the outputs + fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] + + # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) + sig = fw_trace.signature_with_no_ctx() + is_fw_final_trace = sig.startswith('def augmented') + + # Filter the output tuple + # These location might change if the implementation of the automatic + # differentiation transform changes. The saved tensors are the second output + # of the return statement. There's a prototype changing the saved tensors to + # be part of the output of a special symbol + # https://github.com/Lightning-AI/lightning-thunder/pull/214 + saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] + + # te_used = is_te_used(trace) + # Retrieve the compile time arguments + input_args = build_static_args(trace.args, te_used=is_te_used(trace)) + + # Now, we expected that if the fw trace is a final trace also the bw trace is a final one. And vice versa + if is_fw_final_trace: + # Swap saved_for_backward_traces + saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) + # Subsitute the static inputs for saved_for_backward with the runtime ones + input_args.pop(0) + input_args.insert(0, saved_for_bw) + else: + # Currently single trace region backward trace receives as input the saved_for_bw tensors plus some others. + # They are indexed like [saved_for_bw, others...] + # NOTE: This may change in the future + """ + Example: + @torch.no_grad() + @no_autocast + def _cudnn_sdpa_bwd_wrapper(query, key, value, attn_mask, dropout_p=0.0, is_causal=False, *, scale=None): + # query: "cuda:0 bf16[32, 8, 128, 64]" + # key: "cuda:0 bf16[32, 8, 128, 64]" + # value: "cuda:0 bf16[32, 8, 128, 64]" + # dropout_p: "float 0.0" + # is_causal: "bool False" + (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, dropout_p, is_causal, scale=None) + return (t0, [query, key, value, dropout_p, is_causal, t0, t1, t2, t3]) + + @torch.no_grad() + @no_autocast + def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causal, t0, t1, t2, t3, t4): + (t5, t6, t7) = cudnn_sdpa_bwd(t4, query, key, value, None, dropout_p, is_causal, t0, t1, t2, t3, scale=None, cat_grad_qkv=False) + return {'query': t5, 'key': t6, 'value': t7, 'attn_mask': None, 'dropout_p': None, 'is_causal': None, 'scale': None} + + See how the backward trace need t4 as argument recoveered from the static args + """ + updated_input_args = [t for t in saved_for_bw_C0] + updated_input_args.extend(input_args[len(updated_input_args):]) # Should be only one variable but leave this dyanamic + input_args = updated_input_args + + return input_args + # Input args for the trace to benchmark - input_args = [] + # input_args = [] # Check for correctness if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: @@ -546,20 +613,22 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa if apply_del_last_used: trace = del_last_used(trace) - # Handle TE traces + # # Handle TE traces te_used = is_te_used(trace) - sdpa_ex_bw_used = is_sdpa_ex_bw_used(trace) - if te_used and sdpa_ex_bw_used: - raise AssertionError("Not handled") + # sdpa_ex_bw_used = is_sdpa_ex_bw_used(trace) + # if te_used and sdpa_ex_bw_used: + # raise AssertionError("Not handled") if te_used: cached_te_fp8_autocast_value = trace._include_te_fp8_autocast trace._include_te_fp8_autocast = True - TE_backward_trace_preprocessing() + # TE_backward_trace_preprocessing() # Fix sdpaex arguments for backward benchmarks - elif sdpa_ex_bw_used: - SDPA_backward_trace_preprocessing() + # elif sdpa_ex_bw_used: + # SDPA_backward_trace_preprocessing() # "Default" trace, parse the input args...(input args parsing will be performed by the TE trace handling) + if is_backward_trace(trace): + input_args = backward_trace_args_preprocess() else: input_args: list = build_static_args(trace.args) @@ -781,10 +850,13 @@ def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | l res.append(None) else: raise AssertionError(f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}') + # Outer container must be a list return tuple(res) if level > 0 else res def reorder_executors_list(executors: Sequence): from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options + from thunder.executors.torch_compile import torch_compile_ex + from thunder.executors.nvfuserex_impl import ex as nvfuser_ex reordered = [] options = get_fw_bw_split_backends_options() @@ -805,4 +877,16 @@ def reorder_executors_list(executors: Sequence): for ex in executors: if ex not in reordered: reordered.append(ex) + + # NOTE: Currently the autotuner expects at least one Fusion executor otherwise it won't work. + # If other techniques will be added then this constraint will not be necessary + found = False + for ex in reordered: + if are_inputs_names and (ex == nvfuser_ex.name or ex == torch_compile_ex.name): + found = True + elif (ex == nvfuser_ex or ex == torch_compile_ex): + found = True + if not found: + reordered.append(nvfuser_ex) + return reordered From f7a8b16d85ce4e0aa86d24f233d1e5ecd2cfe1fb Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 11:48:45 +0300 Subject: [PATCH 068/171] Benchmark bw fn with runtime args for all traces --- thunder/backend_optimizer/optimizer.py | 2 +- thunder/backend_optimizer/utils.py | 197 ++----------------------- thunder/core/vjp_utils.py | 14 +- 3 files changed, 27 insertions(+), 186 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index dd12cc476c..f3e74b8dd5 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -19,7 +19,7 @@ def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None) -> list | #Current configuration options: dict[str, list] = { # TODO: filter out TE only if requested - # 'linear': [transformer_engine_ex], + 'linear': [transformer_engine_ex], 'scaled_dot_product_attention': [sdpa_ex, cudnn_ex, fa3_ex], } diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 53b0e94b27..18601b7380 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -269,24 +269,6 @@ def is_te_used(trace: TraceCtx) -> bool: return True return False -def is_te_ex_bw_used(trace: TraceCtx) -> bool: - from thunder.executors.transformer_engineex import te_functional_linear_backward_name - for bsym in trace.bound_symbols: - if bsym.sym.name == te_functional_linear_backward_name: - return True - return False - -def is_sdpa_ex_bw_used(trace: TraceCtx) -> bool: - from thunder.executors.sdpaex import ( - sdpaex_scaled_dot_product_efficient_attention_backward_name as n1, - sdpafx_scaled_dot_product_efficient_attention_backward_name as n2, - ) - - for bsym in trace.bound_symbols: - if bsym.sym.name == n1 or bsym.sym.name == n2: - return True - return False - def is_backward_trace(trace: TraceCtx) -> bool: sig = trace.signature_with_no_ctx() return sig.find('backward') >= 0 @@ -398,148 +380,6 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl def build_static_args(sequence: Sequence, **kwargs) -> list: return transform_proxy_to_torch(sequence, level=0, **kwargs) - def TE_backward_trace_preprocessing(): - nonlocal input_args - # Due to the fw and bw split benchmarking we have to check the bw nature by looking at the bsym - is_bw = is_te_ex_bw_used(trace) - """ - If transformer_engine executor is used and it is the bw function we have to recover the forward context and saved_for_backward tensors from the forward trace. - Why we can't generate saved_for_bw tensors from static compiation data? Yes we could but the static compilation data are broken. See the function below: - - def fix_te_backward_inputs(inputs: list): - # inputs = old saved_for_backward - saved_for_bw = [] - for i, e in enumerate(inputs[0][0]): - # This tensor should be an uint8 https://github.com/NVIDIA/TransformerEngine/blob/4edcff5777be08b6f89658572c433aa8f36acf0d/transformer_engine/pytorch/module/linear.py#L366 - if i == 1: - inputmat_t = e - if inputmat_t.dtype != torch.uint8: - inputmat_t = torch.randint(0, 8, (e.shape), dtype=torch.uint8, device=e.device) - saved_for_bw.append(inputmat_t) - else: - saved_for_bw.append(e) - - fixed_inputs_first_index = tuple([tuple(saved_for_bw), inputs[0][1]]) - return fixed_inputs_first_index - """ - if is_bw and 'fw_trace' not in kwargs: - raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with TE executor') - elif is_bw: - # print('TE Benchmarking fw trace for bw') - fw_trace = kwargs.get('fw_trace', None) - if not isinstance(fw_trace, TraceCtx): - raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') - # Run the fw trace and get the runtime outputs - fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] - - # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) - sig = fw_trace.signature_with_no_ctx() - is_fw_final_trace = sig.startswith('def augmented') - - # Filter the output tuple - saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] - - # After got the Context object from the fw pass we can build the input args list for the bw pass - # The underlying API will generate TE.Floa8 tensors also hence it must know if we are dealing with torch.float8 or TE.Float8 - input_args = build_static_args(trace.args, te_used=True) - - # Now, we expect that if the fw trace is a final trace also the bw trace is a final one. And vice versa - if is_fw_final_trace: - # Swap saved_for_backward_traces - saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) - # Subsitute the static inputs for saved_for_backward with the runtime ones - input_args.pop(0) - input_args.insert(0, saved_for_bw) - else: - # Transformer Engine single trace region backward trace receives as input the saved_for_bw tensors plus some others. - # They are indexed like [saved_for_bw, others...] - # NOTE: This may change in the future - """ - Example: - @transformer_engine.fp8_autocast(fp8_recipe=te_fp8_recipe) - @torch.no_grad() - @no_autocast - def _linear_grad(a, w, b): - # a: "cuda:0 f32[768, 4096]" - # w: "cuda:0 f32[4096, 4096]" - # b: "cuda:0 f32[4096]" - (t5, (t0, t1, t2, t3, _, t4), ctx_te_2) = te_linear_2(a, w, b) - return (t5, [ctx_te_2, t0, t1, t2, t3, t4]) - Trace - import torch - from thunder.executors.torchex import no_autocast - - @torch.no_grad() - @no_autocast - def linear_backward(ctx_te_2, t0, t1, t2, t3, t4, t6): - (t7, t8, t9) = te_functional_linear_backward((768, 4096), (4096, 4096), (4096,), ctx_te_2, (t0, t1, t2, t3, None, t4), t6) - return {'a': t7, 'w': t8, 'bias': t9} - - See how the backward trace need t6 as argument recoveered from the static args - """ - updated_input_args = [t for t in saved_for_bw_C0] - updated_input_args.extend(input_args[len(updated_input_args):]) - input_args = updated_input_args - # Forward pass - else: - input_args = build_static_args(trace.args) - - def SDPA_backward_trace_preprocessing(): - nonlocal input_args - if 'fw_trace' not in kwargs: - raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with sdpa executor') - fw_trace = kwargs.get('fw_trace', None) - if not isinstance(fw_trace, TraceCtx): - raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') - # Run the fw trace and get the outputs - fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] - - # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) - sig = fw_trace.signature_with_no_ctx() - is_fw_final_trace = sig.startswith('def augmented') - - # Filter the output tuple - saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] - - # Retrieve the compile time arguments - input_args = build_static_args(trace.args) - - # Now, we expected that if the fw trace is a final trace also the bw trace is a final one. And vice versa - if is_fw_final_trace: - # Swap saved_for_backward_traces - saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) - # Subsitute the static inputs for saved_for_backward with the runtime ones - input_args.pop(0) - input_args.insert(0, saved_for_bw) - else: - # SDPA single trace region backward trace receives as input the saved_for_bw tensors plus some others. - # They are indexed like [saved_for_bw, others...] - # NOTE: This may change in the future - """ - Example: - @torch.no_grad() - @no_autocast - def _cudnn_sdpa_bwd_wrapper(query, key, value, attn_mask, dropout_p=0.0, is_causal=False, *, scale=None): - # query: "cuda:0 bf16[32, 8, 128, 64]" - # key: "cuda:0 bf16[32, 8, 128, 64]" - # value: "cuda:0 bf16[32, 8, 128, 64]" - # dropout_p: "float 0.0" - # is_causal: "bool False" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, dropout_p, is_causal, scale=None) - return (t0, [query, key, value, dropout_p, is_causal, t0, t1, t2, t3]) - - @torch.no_grad() - @no_autocast - def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causal, t0, t1, t2, t3, t4): - (t5, t6, t7) = cudnn_sdpa_bwd(t4, query, key, value, None, dropout_p, is_causal, t0, t1, t2, t3, scale=None, cat_grad_qkv=False) - return {'query': t5, 'key': t6, 'value': t7, 'attn_mask': None, 'dropout_p': None, 'is_causal': None, 'scale': None} - - See how the backward trace need t4 as argument recoveered from the static args - """ - updated_input_args = [t for t in saved_for_bw_C0] - updated_input_args.extend(input_args[len(updated_input_args):]) - input_args = updated_input_args - def backward_trace_args_preprocess() -> list: if 'fw_trace' not in kwargs: raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with sdpa executor') @@ -561,9 +401,8 @@ def backward_trace_args_preprocess() -> list: # https://github.com/Lightning-AI/lightning-thunder/pull/214 saved_for_bw_C0 = fw_output[1] if not is_fw_final_trace else fw_output[1][0] - # te_used = is_te_used(trace) - # Retrieve the compile time arguments - input_args = build_static_args(trace.args, te_used=is_te_used(trace)) + # The underlying API will generate TE.Float8 tensors also, hence it must know if TE executor is used or not + input_args = build_static_args(trace.args, te_used=te_used) # Now, we expected that if the fw trace is a final trace also the bw trace is a final one. And vice versa if is_fw_final_trace: @@ -603,9 +442,6 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa return input_args - # Input args for the trace to benchmark - # input_args = [] - # Check for correctness if trace.bound_symbols[-1].sym.id != PrimIDs.RETURN: raise AssertionError("Missing return statement") @@ -613,26 +449,24 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa if apply_del_last_used: trace = del_last_used(trace) - # # Handle TE traces - te_used = is_te_used(trace) - # sdpa_ex_bw_used = is_sdpa_ex_bw_used(trace) - # if te_used and sdpa_ex_bw_used: - # raise AssertionError("Not handled") - + # Handle TE traces + cd = get_compile_data() + # We might benchmarking a partial trace where the TE symbol is not included yet, in this case rely on the compile option which tells us + # that afterwards at least one TE symbol will be included + # NOTE: compile data could be None if this benchmark util is used outside the compilation process. If this is the case we are benchmarking + # a whole trace (in theory) and is_te_used API will return the needed result. + te_used = (cd.compile_options.get('te_used', False) if cd else False) or is_te_used(trace) if te_used: cached_te_fp8_autocast_value = trace._include_te_fp8_autocast trace._include_te_fp8_autocast = True - # TE_backward_trace_preprocessing() - # Fix sdpaex arguments for backward benchmarks - # elif sdpa_ex_bw_used: - # SDPA_backward_trace_preprocessing() - # "Default" trace, parse the input args...(input args parsing will be performed by the TE trace handling) + + # Build trace arguments: forward trace will receive compile time tensors while + # backward trace will receive dynamic inputs (runtime) to match real training env. if is_backward_trace(trace): input_args = backward_trace_args_preprocess() + # Forward or computational trace, parse the compile time input args... else: - input_args: list = build_static_args(trace.args) - - trace_tok = set_tracectx(trace) + input_args: list = build_static_args(trace.args, te_used=te_used) # Obtain the python executable string executable_str = trace.python() @@ -640,6 +474,8 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa if show_func: print(inspect.getsource(executable)) + trace_tok = set_tracectx(trace) + t = float("inf") m = float("inf") answer = None @@ -865,7 +701,6 @@ def reorder_executors_list(executors: Sequence): # Put these in front to be picked up by _get_gradfn_and_executor for _, v in options.items(): - print(v) for ex in v: if are_inputs_names: if ex.name in executors: diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index caf904652b..d321bea478 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -201,7 +201,7 @@ def bw_fn(*args, **kwargs): from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options from thunder.backend_optimizer.utils import benchmark_trace - # In order define this unique trace region we need a unique id + # In order define this unique trace region we need an unique id key = (bsym.sym, Executor(f'{id(bsym)}-autotuned'), subkey := _make_cache_key(bsym.args, bsym.kwargs)) # We do check the cache here as the key in the inner fn does not know about this special id cached_result = _cache.get(key, None) if subkey is not None else None @@ -211,7 +211,10 @@ def bw_fn(*args, **kwargs): # Get the possible backends for the current bsym backends = get_fw_bw_split_backends_options(bsym) - assert backends + if not backends: + raise AssertionError( + f"No enabled backends found for {bsym.sym.name} but an executor for that symbol it is present in the executors list. Either remove that from the executors list or enable at least one backend for {bsym.sym.linear} inside 'get_fw_bw_split_backends_options'." + ) cached_executors_list = list(cd.executors_list) # Retrieve all the executors which are requested to be used @@ -232,8 +235,9 @@ def bw_fn(*args, **kwargs): requested_executors_list_for_bsym.insert(0, b) cd.executors_list = requested_executors_list_for_bsym fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(True) - fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=20, apply_del_last_used=False) - bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=20, apply_del_last_used=False, fw_trace=fw_trace) + # What should be the optimal iter? + fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=100, apply_del_last_used=False) + bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=100, apply_del_last_used=False, fw_trace=fw_trace) cost = fw_time + bw_time if cd.compile_options['autotune_type'] == OptimizerType.RUNTIME else fw_mem + bw_mem if cost < best.cost: best = SplitFwBwBenchmarkUtils(cost = cost, fw_fn = fw_fn, bw_fn = bw_fn, executor = b) @@ -244,6 +248,8 @@ def bw_fn(*args, **kwargs): # Update the compile options cd.compile_options["executors_placed_by_fw_bw_split"].add(best.executor) + from thunder.executors.transformer_engineex import transformer_engine_ex + cd.compile_options |= {'te_used': True if best.executor == transformer_engine_ex else False} # Restore executor list for downstream optimizations cd.executors_list = cached_executors_list # The executors used in this pass will be updated after the termination of the forward_and_backward_from_trace call From a0cc9ee4c015a70f6e1452684a3a769396402ff9 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 13:18:39 +0300 Subject: [PATCH 069/171] Restore input --- thunder/benchmarks/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 395f614ffe..00c7ac844c 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -125,7 +125,7 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): - y = m(*input if isinstance(input, tuple) else input) + y = m(input) y.sum().backward() start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -139,7 +139,7 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) start_events[i].record(stream) - y = m(*input if isinstance(input, tuple) else input) + y = m(input) loss = y.sum() loss.backward() end_events[i].record(stream) From 96a4f93cc5aa55ed871056647dcefa6fb0a1c066 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 13:21:37 +0300 Subject: [PATCH 070/171] Updated litgpt runner --- examples/dev/litGPT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 85dc00d35a..76fd8430d6 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -10,7 +10,7 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: self.autotune_type = autotune_type self.batch_size = batch_size -layers = [Test(8, 'runtime', 1), Test(8, 'runtime', 4)] +layers = [Test(4, 'runtime', 1)] model_name = 'Llama-3-8B' @@ -25,7 +25,7 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: x = torch.randint(1, model.config.vocab_size, (test.batch_size, 512)) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'torchcompile', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'torchcompile', 'cudnn', 'torch', 'python'], use_cudagraphs=False) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) From 2e796a31e6cd5cf1b7d0d607a1d8a5effbdd5655 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 13:23:22 +0300 Subject: [PATCH 071/171] Updated litgpt runner --- examples/dev/litGPT.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 76fd8430d6..b8d8c533b5 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -25,7 +25,12 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: x = torch.randint(1, model.config.vocab_size, (test.batch_size, 512)) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type=test.autotune_type, executors = ['nvfuser', 'torchcompile', 'cudnn', 'torch', 'python'], use_cudagraphs=False) + jmodel_auto = thunder.jit( + model, + autotune_type=test.autotune_type, + executors=["nvfuser", "torchcompile", "cudnn", "sdpa", "fa3"], + use_cudagraphs=False, + ) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) From a4d5fa5d3534c5117b2274cbeb1866f6406d98eb Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 16:53:00 +0300 Subject: [PATCH 072/171] Enhanced logs --- thunder/backend_optimizer/optimizer.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index f3e74b8dd5..e44ef117a3 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -542,8 +542,9 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): bound_symbol_groups = fuse_bound_symbols( self.trace, merge_fn ) - log(f"Num of groups = {len(bound_symbol_groups)}", level=LogLevel.DEBUG) + log(f"Number of Fusion groups = {len(bound_symbol_groups)}", level=LogLevel.DEBUG) + # Print fusion groups if requested for id, group in enumerate(bound_symbol_groups): log(f"Group id: {id}", level=LogLevel.DEBUG) for sub in group: @@ -555,9 +556,9 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): dict_mem_strat: dict[str, Executor] = {} increasing_symbols = [] for group_id, group in enumerate(bound_symbol_groups): - log(f"Group id: {group_id}", level=LogLevel.DEBUG) - log(f"group start = {group[0].sym.name}", level=LogLevel.DEBUG) - log(f"group end = {group[-1].sym.name}", level=LogLevel.DEBUG) + log(f"Fusion group id: {group_id}", level=LogLevel.DEBUG) + log(f"Fusion group start = [{group[0].output.name if hasattr(group[0].output, 'name') else 'unknown'} = {group[0].sym.name}]", level=LogLevel.DEBUG) + log(f"Fusion group end = [{group[-1].output.name if hasattr(group[-1].output, 'name') else 'unknown'} = {group[-1].sym.name}]", level=LogLevel.DEBUG) if group[0].sym.name != "return": increasing_symbols += group From 0782407e0da6e05b148195dca79af426053e8eb4 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 16:53:48 +0300 Subject: [PATCH 073/171] Unpacking sequences during search of not used proxies --- thunder/backend_optimizer/utils.py | 63 +++++++++++++++++++----------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 18601b7380..78a41fd246 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -50,10 +50,10 @@ def get_first_available_operator_executor( return Executor(name=empty_hash) -def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[TensorProxy]: - def is_in_sequence(seq: Sequence[Any], t: TensorProxy): +def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[Proxy]: + def is_in_sequence(seq: Sequence[Any], t: Proxy): for e in seq: - if isinstance(e, TensorProxy) and e.name == t.name: + if hasattr(e, "name") and hasattr(t, "name") and e.name == t.name: return True return False @@ -63,25 +63,44 @@ def is_possible_out(name: str): num = name[1:] return num.isdigit() - ans: list[TensorProxy] = [] - for b in trace_in.bound_symbols: + def flatten_sequence(sequence: Sequence) -> list: + res = [] + for e in sequence: + if isinstance(e, Sequence): + res.extend(flatten_sequence(e)) + # Skip Nones as they are not useful + elif e is not None: + res.append(e) + return res + + def unpack_output(out) -> Sequence[Proxy]: + if issubclass(type(out), Proxy): + return [out] + elif isinstance(out, Sequence): + return flatten_sequence(out) + else: + raise RuntimeError(f'Unpack operation not defined for {type(out)}') + + ans: list[Proxy] = [] + # Currently this is O(max(len(bsym.output)) * N^2) + # Can we check only bsym after the one in the outer loop in the inner loop (over trace.bound_symbols) ? + for a in trace_in.bound_symbols: f = False - # Not a tensor - if not isinstance(b.output, TensorProxy): - continue - # Not a produced tensor - if not is_possible_out(b.output.name): - continue - for test in trace_in.bound_symbols: - if ( - test.args is not None - and (isinstance(test.args, tuple) or isinstance(test.args, list)) - and is_in_sequence(test.args, b.output) - ): - f = True - break - if not f: - ans.append(b.output) + unpacked_out = unpack_output(a.output) + for e in unpacked_out: + # None values are checked inside the unpack_output fn + for b in trace_in.bound_symbols: + if ( + b.args is not None + and isinstance(b.args, Sequence) + and is_in_sequence(b.args, e) + ): + f = True + break + if not f: + ans.append(e) + from thunder.backend_optimizer.optimizer import log, LogLevel + log(f'Returning not used proxies: {[p.name for p in ans]}', level=LogLevel.DEBUG) return ans @@ -224,7 +243,7 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: # Restore subsymbols # TODO (matteochen): Improve this search for k, v in cached_subsymbols.items(): - # Note some symbols may be cut out by the fusion pass -> CSE + # NOTE: Some symbols may be cut out by the fusion pass -> CSE # For example: # a = 1 + 1 # b = 1 + 1 From ed1a2e27c6f6c1def766343646e565bf36046cee Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 16:55:59 +0300 Subject: [PATCH 074/171] Updated comment --- thunder/backend_optimizer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 78a41fd246..576b22ab84 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -412,7 +412,7 @@ def backward_trace_args_preprocess() -> list: sig = fw_trace.signature_with_no_ctx() is_fw_final_trace = sig.startswith('def augmented') - # Filter the output tuple + # Filter the C0 tuple # These location might change if the implementation of the automatic # differentiation transform changes. The saved tensors are the second output # of the return statement. There's a prototype changing the saved tensors to From c1560b01edc5e9ec4a1b9c0c54ff46277e892c5d Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 16:59:10 +0300 Subject: [PATCH 075/171] Updated model tests --- examples/dev/LLaMAMLP.py | 28 +++++++++++++++++----------- examples/dev/nanogpt.py | 3 +-- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 9e89b0139c..708b529ab5 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -1,6 +1,7 @@ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark +from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_total_benchmark + class LLaMAMLP(torch.nn.Module): def __init__(self, n_embd, intermediate_size) -> None: @@ -8,13 +9,15 @@ def __init__(self, n_embd, intermediate_size) -> None: self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False) self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False) self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False) + def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) x = torch.nn.functional.silu(x_fc_1) * x_fc_2 return self.proj(x) -with torch.device('cuda'): + +with torch.device("cuda"): mult = 1 a = 4096 * mult b = 11008 * mult @@ -23,13 +26,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = LLaMAMLP(a, b) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', use_cudagraphs=False) + jmodel_auto = thunder.jit( + model, + autotune_type="runtime", + executors=["nvfuser", "torchcompile", "transformer_engine"], + use_cudagraphs=False, + ) - print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 - print('Results with thunder benchmark:') + print("Results with thunder benchmark:") fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], @@ -43,11 +51,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] + labels = ["def", "auto"] inputs = [x, x] - print('\nResults with torch fw bw benchmark:') + print("\nResults with torch fw bw benchmark:") torch_fw_bw_benchmark(callables, labels, inputs, iters) - print('\nResults with torch total benchmark:') + print("\nResults with torch total benchmark:") torch_total_benchmark(callables, labels, inputs, iters) - torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) - diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index 55005b5b03..bcf7e91781 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -50,8 +50,7 @@ def run(target: str = 'runtime'): model.to(device) jmodel_def = thunder.jit(model) - # Currently sdpa does not work? - jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'sdpa'], use_cudagraphs=False) + jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'sdpa', 'transformer_engine'], use_cudagraphs=False) if compile: print("Compiling model...") From b890ac603ff333c6c2d0771944d16501a2017613 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 22 Aug 2024 17:00:18 +0300 Subject: [PATCH 076/171] Updated model tests --- examples/dev/litGPT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index b8d8c533b5..51de4ad613 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -10,7 +10,7 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: self.autotune_type = autotune_type self.batch_size = batch_size -layers = [Test(4, 'runtime', 1)] +layers = [Test(4, 'runtime', 1), Test(4, 'memory', 1)] model_name = 'Llama-3-8B' @@ -28,7 +28,7 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: jmodel_auto = thunder.jit( model, autotune_type=test.autotune_type, - executors=["nvfuser", "torchcompile", "cudnn", "sdpa", "fa3"], + executors=["nvfuser", "torchcompile", "cudnn", "sdpa", "fa3", "transformer_engine"], use_cudagraphs=False, ) From 4228ff89ec3d9b391717d305bb4b45cfb5ddf48d Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 23 Aug 2024 11:21:50 +0300 Subject: [PATCH 077/171] Updated log and comment --- thunder/backend_optimizer/optimizer.py | 4 ++-- thunder/core/vjp_utils.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index e44ef117a3..6bf8c84482 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -563,10 +563,10 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): if group[0].sym.name != "return": increasing_symbols += group - # Is not a fusion region, get the optimal executor (OperatorExecutor) + # We assign to a Fusion executor only region with at least 2 elements. Otherwise let the best OperatorExecutor pick the symbol up if len(group) < 2: current_bsym = group[0] - log(f"--> Single group: {current_bsym.sym.name}", level=LogLevel.DEBUG) + log(f"--> Single group: [{current_bsym.output.name if hasattr(current_bsym.output, 'name') else 'unknown'} = {current_bsym.sym.name}]", level=LogLevel.DEBUG) # Filter out all possible candidates for the current symbol candidate_executors = [ ex diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index d321bea478..112b4d23ba 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -236,6 +236,7 @@ def bw_fn(*args, **kwargs): cd.executors_list = requested_executors_list_for_bsym fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(True) # What should be the optimal iter? + # TODO: make benchmark info taken from an autotuner config fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=100, apply_del_last_used=False) bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=100, apply_del_last_used=False, fw_trace=fw_trace) cost = fw_time + bw_time if cd.compile_options['autotune_type'] == OptimizerType.RUNTIME else fw_mem + bw_mem From 3682c8297050082b09eb9b060ca7c80ada5231fc Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 23 Aug 2024 13:18:55 +0300 Subject: [PATCH 078/171] Fixed comment --- thunder/backend_optimizer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 576b22ab84..c1d7650c80 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -619,7 +619,7 @@ def update_compile_options_executor_list_after_fw_bw_split() -> None: if ex in executors_list: executors_list.remove(ex) - # Putting at the front event though order does not matter + # Putting at the front even though order does not matter for ex in cd.compile_options['executors_placed_by_fw_bw_split']: executors_list.insert(0, ex) From f6b8d163a5fc2fadc09fd6abb14dc824796ac4d1 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 23 Aug 2024 15:41:44 +0300 Subject: [PATCH 079/171] Removed file --- examples/dev/sdpaex_qsize.txt:q | 707 -------------------------------- 1 file changed, 707 deletions(-) delete mode 100644 examples/dev/sdpaex_qsize.txt:q diff --git a/examples/dev/sdpaex_qsize.txt:q b/examples/dev/sdpaex_qsize.txt:q deleted file mode 100644 index a27d6e5e39..0000000000 --- a/examples/dev/sdpaex_qsize.txt:q +++ /dev/null @@ -1,707 +0,0 @@ -filtered: {'linear': [], 'scaled_dot_product_attention': [thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('cudnn')]} -->unpack_trivial -->unpack_trivial -->unpack_trivial -->scaled_dot_product_attention -3 -> cudnn -->scaled_dot_product_attention -4 -> fa3 -->add -->python_return -Assigned: {'scaled_dot_product_attention': {3: thunder.extend.OperatorExecutor('cudnn'), 4: thunder.extend.OperatorExecutor('fa3')}} -Input ex: (thunder.extend.OperatorExecutor('nvfuser'), thunder.extend.OperatorExecutor('torchcompile'), thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('cudnn'), thunder.extend.OperatorExecutor('torch'), thunder.extend.OperatorExecutor('python')) -================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Optimizing linear -================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Skipping optimization for linear as not requested or not present in computation_trc. -================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Optimizing scaled_dot_product_attention -================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Executors to bench for scaled_dot_product_attention: [thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('cudnn'), thunder.extend.OperatorExecutor('torch')] -================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Testing compile data executors: [thunder.extend.OperatorExecutor('sdpa'), thunder.extend.OperatorExecutor('torch'), thunder.extend.OperatorExecutor('nvfuser'), thunder.extend.OperatorExecutor('torchcompile'), thunder.extend.OperatorExecutor('python')] -#FW after trace call -._trace at 0x7cdda07cfe20> -#FW after trace call -._trace at 0x7cdd9ee381f0> -#FW after trace call -._trace at 0x7cdda07cfe20> -#FW after trace call -._trace at 0x7cdd9ee38b80> -#FW after trace call -._trace at 0x7cdda07cfd90> -================================================================================ Autotune: Executors: -================================================================================ Autotune: sdpa -> is operator = True, is fusion = False -================================================================================ Autotune: torch -> is operator = True, is fusion = False -================================================================================ Autotune: nvfuser -> is operator = False, is fusion = True -================================================================================ Autotune: torchcompile -> is operator = False, is fusion = True -================================================================================ Autotune: python -> is operator = True, is fusion = False -================================================================================ Autotune: New forward trace to optimize (strat = OptimizerType.RUNTIME): -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" - t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" - t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) -================================================================================ Autotune: Searching best placement for fusion executor = nvfuser -================================================================================ Autotune: Searching best placement for fusion executor = torchcompile -================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.014675200078636408 ms, mem = 0.002941131591796875 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = nvFusion0(t0) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) -================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.014627200085669756 ms, mem = 0.002941131591796875 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = nvFusion0(t0) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) -================================================================================ Autotune: Caching fw candidate [compile option: None] -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = nvFusion0(t0) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) -================================================================================ Autotune: Caching fw candidate [compile option: None] -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = nvFusion0(t0) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) -================================================================================ Autotune: End fw time mem pair -================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.026291199773550034 ms, mem = 0.00331878662109375 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = TorchCompile0(t0, t7) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" - # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) -================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.02619839971885085 ms, mem = 0.00331878662109375 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = TorchCompile0(t0, t7) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" - # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) -================================================================================ Autotune: Caching fw candidate [compile option: None] -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = TorchCompile0(t0, t7) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" - # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) -================================================================================ Autotune: Caching fw candidate [compile option: None] -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - (t7, t8, t9, t10, _, _, t11, t12, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = TorchCompile0(t0, t7) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t15 = prims.convert_element_type(t7, dtypes.float32) # t15: "cuda:0 f32[4, 128, 6, 64]" - # t16 = ltorch.add(t14, t15, alpha=None) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t15) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12), ()) -================================================================================ Autotune: End fw time mem pair -================================================================================ Autotune: New backward trace to optimize (strat = OptimizerType.RUNTIME): -# Constructed by Backward pass -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, C1, = saved_for_backward - t18, = cotangents - query, key, value, t0, t1, t2, t3, t4, t5, t7, t8, t9, t10, t11, t12, = C0 - # C1 (empty sequence) - t14 = prims.convert_element_type(t18, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - t15 = prims.convert_element_type(t14, dtypes.bfloat16) # t15: "cuda:0 bf16[4, 128, 6, 64]" - t16 = prims.convert_element_type(t14, dtypes.bfloat16) # t16: "cuda:0 bf16[4, 128, 6, 64]" - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t15, query, key, value, t7, t8, t9, t10, 6, 6, 0.0, False, t11, t12, scale=None) - (t21, t22, t23) = sdpafx_scaled_dot_product_efficient_attention_backward(t16, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - t24 = prims.convert_element_type(t17, dtypes.float32) # t24: "cuda:0 f32[4, 128, 6, 64]" - t25 = prims.convert_element_type(t21, dtypes.float32) # t25: "cuda:0 f32[4, 128, 6, 64]" - t26 = prims.add(t24, t25) # t26: "cuda:0 f32[4, 128, 6, 64]" - t27 = prims.convert_element_type(t26, dtypes.bfloat16) # t27: "cuda:0 bf16[4, 128, 6, 64]" - t28 = prims.convert_element_type(t19, dtypes.float32) # t28: "cuda:0 f32[4, 128, 6, 64]" - t29 = prims.convert_element_type(t22, dtypes.float32) # t29: "cuda:0 f32[4, 128, 6, 64]" - t30 = prims.add(t28, t29) # t30: "cuda:0 f32[4, 128, 6, 64]" - t31 = prims.convert_element_type(t30, dtypes.bfloat16) # t31: "cuda:0 bf16[4, 128, 6, 64]" - t32 = prims.convert_element_type(t20, dtypes.float32) # t32: "cuda:0 f32[4, 128, 6, 64]" - t33 = prims.convert_element_type(t23, dtypes.float32) # t33: "cuda:0 f32[4, 128, 6, 64]" - t34 = prims.add(t32, t33) # t34: "cuda:0 f32[4, 128, 6, 64]" - t35 = prims.convert_element_type(t34, dtypes.bfloat16) # t35: "cuda:0 bf16[4, 128, 6, 64]" - return (t27, t31, t35) -================================================================================ Autotune: Backward optimization with fw from nvfuser -Current fw cached ctx: -# Constructed by Autotuned transform for execution (took 3006 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3, _, _, t4, t5, _) = sdpafx_grad_forward_scaled_dot_product_efficient_attention(query, key, value, 0.0, False, scale=None) - [t17] = nvFusion0(t0) - # t14 = prims.convert_element_type(t0, dtypes.float32) # t14: "cuda:0 f32[4, 128, 6, 64]" - # t16 = prims.add(t14, t14) # t16: "cuda:0 f32[4, 128, 6, 64]" - # t17 = prims.convert_element_type(t16, dtypes.bfloat16) # t17: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t17, 'flat_args': [query, key, value], 'flat_output': (t17,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t0, t1, t2, t3, t4, t5), ()) -Options: None -================================================================================ Autotune: Searching best placement for fusion executor = nvfuser -#FW after trace call -._trace at 0x7cdd4a1d77f0> -#FW after trace call -._trace at 0x7cdd4a1d4160> -#FW after trace call -._trace at 0x7cdd4a0201f0> -#FW after trace call -._trace at 0x7cdd4a1d4f70> -#FW after trace call -._trace at 0x7cdd4a1d4160> -TORCH DTYPE t2 torch.int64 -TORCH DTYPE t3 torch.int64 -TORCH DTYPE t4 torch.int64 -TORCH DTYPE t5 torch.int64 -#FN EXECUTION FAILED: -# Constructed by Delete Last Used (took 0 milliseconds) -from torch import Tensor -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - del saved_for_backward - t18, = cotangents - del cotangents - key, query, t0, t1, t2, t3, t4, t5, value, = C0 - del C0 - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - del t18, query, key, value, t0, t1, t2, t3, t4, t5 - [t24] = nvFusion1(t17) - t28 = Tensor.to(t19, torch.float32, copy=True) # t28: "cuda:0 f32[4, 128, 6, 64]" - del t19 - t32 = Tensor.to(t20, torch.float32, copy=True) # t32: "cuda:0 f32[4, 128, 6, 64]" - del t20 - t25 = Tensor.to(t17, torch.float32, copy=True) # t25: "cuda:0 f32[4, 128, 6, 64]" - del t17 - t26 = torch.add(t24, t25) # t26: "cuda:0 f32[4, 128, 6, 64]" - del t24, t25 - t30 = torch.add(t28, t28) # t30: "cuda:0 f32[4, 128, 6, 64]" - del t28 - t34 = torch.add(t32, t32) # t34: "cuda:0 f32[4, 128, 6, 64]" - del t32 - t27 = Tensor.to(t26, torch.bfloat16, copy=True) # t27: "cuda:0 bf16[4, 128, 6, 64]" - del t26 - t31 = Tensor.to(t30, torch.bfloat16, copy=True) # t31: "cuda:0 bf16[4, 128, 6, 64]" - del t30 - t35 = Tensor.to(t34, torch.bfloat16, copy=True) # t35: "cuda:0 bf16[4, 128, 6, 64]" - del t34 - return t27, t31, t35 -Traceback (most recent call last): - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace - t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms - raise e - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms - fn(*args) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "thunder.backward_fn_48", line 17, in backward_fn - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl - grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) -RuntimeError: cu_seqlens_q must have dtype int32 - -#FW after trace call -._trace at 0x7cdd9ee3bc70> -TORCH DTYPE t2 torch.int64 -TORCH DTYPE t3 torch.int64 -TORCH DTYPE t4 torch.int64 -TORCH DTYPE t5 torch.int64 -#FN EXECUTION FAILED: -# Constructed by Delete Last Used (took 0 milliseconds) -from torch import Tensor -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - del saved_for_backward - t18, = cotangents - del cotangents - key, query, t0, t1, t2, t3, t4, t5, value, = C0 - del C0 - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - del t18, query, key, value, t0, t1, t2, t3, t4, t5 - t24 = Tensor.to(t17, torch.float32, copy=True) # t24: "cuda:0 f32[4, 128, 6, 64]" - [t27, t31, t35] = nvFusion1(t17, t24, t19, t20) - del t17, t24, t19, t20 - return t27, t31, t35 -Traceback (most recent call last): - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace - t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms - raise e - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms - fn(*args) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "thunder.backward_fn_49", line 17, in backward_fn - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl - grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) -RuntimeError: cu_seqlens_q must have dtype int32 - -#FW after trace call -._trace at 0x7cdd4a0875b0> -TORCH DTYPE t2 torch.int64 -TORCH DTYPE t3 torch.int64 -TORCH DTYPE t4 torch.int64 -TORCH DTYPE t5 torch.int64 -#FN EXECUTION FAILED: -# Constructed by Delete Last Used (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - del saved_for_backward - t18, = cotangents - del cotangents - key, query, t0, t1, t2, t3, t4, t5, value, = C0 - del C0 - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - del t18, query, key, value, t0, t1, t2, t3, t4, t5 - [t27, t31, t35] = nvFusion1(t17, t19, t20) - del t17, t19, t20 - return t27, t31, t35 -Traceback (most recent call last): - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace - t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms - raise e - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms - fn(*args) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "thunder.backward_fn_50", line 16, in backward_fn - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl - grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) -RuntimeError: cu_seqlens_q must have dtype int32 - -#FW after trace call -._trace at 0x7cdd9ee3aef0> -TORCH DTYPE t2 torch.int64 -TORCH DTYPE t3 torch.int64 -TORCH DTYPE t4 torch.int64 -TORCH DTYPE t5 torch.int64 -#FN EXECUTION FAILED: -# Constructed by Delete Last Used (took 0 milliseconds) -from torch import Tensor -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def backward_fn(saved_for_backward, cotangents): - # saved_for_backward: "Collection" - # cotangents: "Collection" - C0, _, = saved_for_backward - del saved_for_backward - t18, = cotangents - del cotangents - key, query, t0, t1, t2, t3, t4, t5, value, = C0 - del C0 - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - del t18, query, key, value, t0, t1, t2, t3, t4, t5 - t24 = Tensor.to(t17, torch.float32, copy=True) # t24: "cuda:0 f32[4, 128, 6, 64]" - del t17 - t28 = Tensor.to(t19, torch.float32, copy=True) # t28: "cuda:0 f32[4, 128, 6, 64]" - del t19 - t32 = Tensor.to(t20, torch.float32, copy=True) # t32: "cuda:0 f32[4, 128, 6, 64]" - del t20 - t26 = torch.add(t24, t24) # t26: "cuda:0 f32[4, 128, 6, 64]" - del t24 - t30 = torch.add(t28, t28) # t30: "cuda:0 f32[4, 128, 6, 64]" - del t28 - t34 = torch.add(t32, t32) # t34: "cuda:0 f32[4, 128, 6, 64]" - del t32 - t27 = Tensor.to(t26, torch.bfloat16, copy=True) # t27: "cuda:0 bf16[4, 128, 6, 64]" - del t26 - t31 = Tensor.to(t30, torch.bfloat16, copy=True) # t31: "cuda:0 bf16[4, 128, 6, 64]" - del t30 - t35 = Tensor.to(t34, torch.bfloat16, copy=True) # t35: "cuda:0 bf16[4, 128, 6, 64]" - del t34 - return t27, t31, t35 -Traceback (most recent call last): - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 505, in benchmark_trace - t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 346, in compute_time_cost_ms - raise e - File "/workspace/workdir/thunder/backend_optimizer/utils.py", line 313, in compute_time_cost_ms - fn(*args) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast - return func(*args, **kwargs) - File "thunder.backward_fn_51", line 17, in backward_fn - (t17, t19, t20) = sdpafx_scaled_dot_product_efficient_attention_backward(t18, query, key, value, t0, t1, t2, t3, 6, 6, 0.0, False, t4, t5, scale=None) - File "/workspace/workdir/thunder/executors/sdpaex.py", line 453, in _scaled_dot_product_flash_attention_backward_impl - grads = torch.ops.aten._scaled_dot_product_flash_attention_backward( - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) - File "/usr/local/lib/python3.10/dist-packages/torch/utils/_device.py", line 79, in __torch_function__ - return func(*args, **kwargs) - File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 1127, in __call__ - return self._op(*args, **(kwargs or {})) -RuntimeError: cu_seqlens_q must have dtype int32 - -================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Failed to place sdpa: Failed to get best time placement -Traceback (most recent call last): - File "/workspace/workdir/thunder/executors/torch_autograd.py", line 458, in split_forward_backward - primal_trace, fw_extrace, bw_extrace, fw_traces, bw_traces = split() - File "/workspace/workdir/thunder/executors/torch_autograd.py", line 248, in split - fw_extrace, bw_extrace = autotune_transform_for_execution( - File "/workspace/workdir/thunder/executors/passes.py", line 164, in autotune_transform_for_execution - optimizer_context.optimize() - File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 1124, in optimize - self.optimizer.optimize() - File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 974, in optimize - _optimize() - File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 887, in _optimize - self._search_candidates() - File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 827, in _search_candidates - _search(ex) - File "/workspace/workdir/thunder/backend_optimizer/optimizer.py", line 717, in _search - raise AssertionError("Failed to get best time placement") -AssertionError: Failed to get best time placement - -================================================================================ Autotune: ================================================================================ Before Autotune Tuning: Testing compile data executors: [thunder.extend.OperatorExecutor('cudnn'), thunder.extend.OperatorExecutor('torch'), thunder.extend.OperatorExecutor('nvfuser'), thunder.extend.OperatorExecutor('torchcompile'), thunder.extend.OperatorExecutor('python')] -#FW after trace call -._trace at 0x7cdd4a022950> -#FW after trace call -._trace at 0x7cdd5bef48b0> -#FW after trace call -._trace at 0x7cdd5bef48b0> -#FW after trace call -._trace at 0x7cdd4a022830> -#FW after trace call -._trace at 0x7cdd9ee3bb50> -================================================================================ Autotune: Executors: -================================================================================ Autotune: cudnn -> is operator = True, is fusion = False -================================================================================ Autotune: torch -> is operator = True, is fusion = False -================================================================================ Autotune: nvfuser -> is operator = False, is fusion = True -================================================================================ Autotune: torchcompile -> is operator = False, is fusion = True -================================================================================ Autotune: python -> is operator = True, is fusion = False -================================================================================ Autotune: New forward trace to optimize (strat = OptimizerType.RUNTIME): -# Constructed by Dead Code Elimination (took 0 milliseconds) -import thunder -import thunder.core.dtypes as dtypes -import thunder.core.prims as prims -import thunder.torch as ltorch -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" - t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" - t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) -================================================================================ Autotune: Searching best placement for fusion executor = nvfuser -================================================================================ Autotune: Searching best placement for fusion executor = torchcompile -================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.010593599872663617 ms, mem = 0.002941131591796875 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - [t11] = nvFusion0(t0) - # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) -================================================================================ Autotune: Benchmark fw end: Trace = [nvfuser] (time = 0.010647999914363026 ms, mem = 0.002941131591796875 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - [t11] = nvFusion0(t0) - # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) -================================================================================ Autotune: Caching fw candidate [compile option: None] -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - [t11] = nvFusion0(t0) - # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) -================================================================================ Autotune: Caching fw candidate [compile option: None] -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - [t11] = nvFusion0(t0) - # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t8) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t0, t1, t2, t3), ()) -================================================================================ Autotune: End fw time mem pair -================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.018396800477057697 ms, mem = 0.00331878662109375 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - [t11] = TorchCompile0(t0, t4) - # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - # t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" - # t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) -================================================================================ Autotune: Benchmark fw end: Trace = [torchcompile] (time = 0.01832640040665865 ms, mem = 0.00331878662109375 GB)": -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - [t11] = TorchCompile0(t0, t4) - # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - # t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" - # t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) -================================================================================ Autotune: Caching fw candidate [compile option: None] -# Constructed by Transform for operator executor execution (took 0 milliseconds) -import torch -from thunder.executors.torchex import no_autocast - -@torch.no_grad() -@no_autocast -def augmented_forward_fn(query, key, value): - # query: "cuda:0 bf16[4, 128, 6, 64]" - # key: "cuda:0 bf16[4, 128, 6, 64]" - # value: "cuda:0 bf16[4, 128, 6, 64]" - (t0, t1, t2, t3) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - (t4, t5, t6, t7) = cudnn_sdpa_fwd(query, key, value, None, 0.0, False, scale=None) - [t11] = TorchCompile0(t0, t4) - # t8 = prims.convert_element_type(t0, dtypes.float32) # t8: "cuda:0 f32[4, 128, 6, 64]" - # t9 = prims.convert_element_type(t4, dtypes.float32) # t9: "cuda:0 f32[4, 128, 6, 64]" - # t10 = ltorch.add(t8, t9, alpha=None) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t10 = prims.add(t8, t9) # t10: "cuda:0 f32[4, 128, 6, 64]" - # t11 = prims.convert_element_type(t10, dtypes.bfloat16) # t11: "cuda:0 bf16[4, 128, 6, 64]" - return {'output': t11, 'flat_args': [query, key, value], 'flat_output': (t11,)}, ((query, key, value, t0, t1, t2, t3, t4, t5, t6, t7), ()) From 14074c41616511947a66db93d8225cb5deffc5bd Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 23 Aug 2024 15:42:36 +0300 Subject: [PATCH 080/171] Fixed nvsight bench when args need to be cloned as in TE --- thunder/backend_optimizer/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index c1d7650c80..6d104d9744 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -313,16 +313,20 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f torch.cuda.synchronize() # Warm up cycles for _ in range(warm_up_iters): - fn(*args) + cloned_args = clone_args(args) + fn(*cloned_args) + del cloned_args # Benchmark torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() for i in range(iters): + cloned_args = clone_args(args) torch.cuda.empty_cache() torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nvsight_fn_name}, iter{i}") - fn(*args) + fn(*cloned_args) torch.cuda.nvtx.range_pop() + del cloned_args torch.cuda.cudart().cudaProfilerStop() return float("inf"), float("inf"), None @@ -339,7 +343,7 @@ def clone_args(args): res.append(clone_args(arg)) else: if isinstance(arg, torch.Tensor): - res.append(arg.clone()) + res.append(arg.clone().detach()) else: res.append(arg) return tuple(res) From 706ff91b68962ae2d1501d539ecd7122b628b4e8 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 23 Aug 2024 17:44:51 +0300 Subject: [PATCH 081/171] Benchmarking TE on llama --- examples/dev/litGPT.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 51de4ad613..f9e6612575 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -4,55 +4,65 @@ import thunder import torch +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + class Test: - def __init__(self, layers: int, autotune_type: str, batch_size: int) -> None: + def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: int) -> None: self.layers = layers self.autotune_type = autotune_type self.batch_size = batch_size + self.seq_len = seq_len -layers = [Test(4, 'runtime', 1), Test(4, 'memory', 1)] +layers = [Test(1, 'runtime', 1, 512)] -model_name = 'Llama-3-8B' +model_name = 'open_llama_3b' for test in layers: try: print('\n\nLayers:', test.layers) cfg = Config.from_name(model_name) + print(cfg) cfg.n_layer = test.layers torch.set_default_dtype(torch.bfloat16) with torch.device('cuda'): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (test.batch_size, 512)) - jmodel_def = thunder.jit(model) + jmodel_def = thunder.jit(model, executors=['cudnn', 'nvfuser']) + jmodel_def_te = thunder.jit(model, executors=['cudnn', 'transformer_engine', 'nvfuser']) jmodel_auto = thunder.jit( model, autotune_type=test.autotune_type, - executors=["nvfuser", "torchcompile", "cudnn", "sdpa", "fa3", "transformer_engine"], + executors=["nvfuser", "cudnn", "sdpa", "transformer_engine"], use_cudagraphs=False, ) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) + print('deviation def_te:', (jmodel_def_te(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 - print(f'Results thunder benchmark ({iters} iters):') fw_traces = [ thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_def_te)[-1], thunder.last_traces(jmodel_auto)[-1], ] bw_traces = [ thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_def_te)[-1], thunder.last_backward_traces(jmodel_auto)[-1], ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) + fw_labels = ["fw_def", "fw_def_te", "fw_auto"] + bw_labels = ["bw_def", "bw_def_te", "bw_auto"] + print(f'Results thunder benchmark ({iters} iters):') + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=True) + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) print(f'\n\nResults torch fw bw benchmark ({iters} iters):') - callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] - inputs = [x, x] + callables = [jmodel_def, jmodel_def_te, jmodel_auto] + labels = ['def', 'def_te', 'auto'] + inputs = [x.clone().detach(), x.clone().detach(), x.clone().detach()] torch_fw_bw_benchmark(callables, labels, inputs, iters) print(f'\n\nResults torch total benchmark ({iters} iters):') torch_total_benchmark(callables, labels, inputs, iters) From 4c32776b812479133885e80e37e0a210baf102f3 Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 24 Aug 2024 00:31:07 +0300 Subject: [PATCH 082/171] nvmath ex, integrated matmul --- thunder/backend_optimizer/optimizer.py | 23 ++++++++------- thunder/executors/nvmathex.py | 41 ++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 11 deletions(-) create mode 100644 thunder/executors/nvmathex.py diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 6bf8c84482..5f857eca19 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -5,6 +5,7 @@ from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, TensorProxy from thunder.core.symbol import BoundSymbol from thunder.core.trace import from_trace, TraceCtx +from thunder.core.transforms import construct_trace from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors from thunder.visualizer.visualizer_helper import Visualizer from typing import Hashable @@ -147,7 +148,7 @@ class LogLevel(Enum): INFO = 1 -log_level: LogLevel = LogLevel.INFO +log_level: LogLevel = LogLevel.DEBUG def log(what: str, level: LogLevel = LogLevel.INFO): @@ -239,7 +240,7 @@ def __init__( self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { "nvfuser": [ FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), - FusionCompileOptionsHelper("nv_enable_matmul", "matmul", PrimIDs.MATMUL, matmul, _matmul_check), + # FusionCompileOptionsHelper("nv_enable_matmul", "matmul", PrimIDs.MATMUL, matmul, _matmul_check), # FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), ] } @@ -591,6 +592,9 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): else: log(f"Available executors for single region:\n{candidate_executors}", level=LogLevel.DEBUG) + # Define the standalone trace in order to benchmark this symbol + subtrace = construct_trace()(current_bsym.sym, *current_bsym.args, **current_bsym.kwargs) + # Helpers candidate_best_time = BenchmarkResult() candidate_best_mem = BenchmarkResult() @@ -598,15 +602,12 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # TODO: enable requests for no remat becnhmarks # TODO: we should consider also FusionExecutor that can execute this single bsym in this beam search for i, candidate in enumerate(candidate_executors): - # Match the current candidate to benchmark partial trace - match_bsym_output(current_bsym, [dict_time_strat, dict_mem_strat], candidate) - # Retrieve partial trace and benchmark, apply remat if possible - trc, _, _ = get_placed_trace(dict_time_strat, increasing_symbols) - # Apply fw bw remat - # if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: - # _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) - # Now, benchmark - t, m, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) + + from thunder.common import transform_for_execution + subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] + log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.DEBUG) + t, m, _ = benchmark_trace(subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) + log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.DEBUG) # Update results if t < candidate_best_time.runtime: candidate_best_time = BenchmarkResult(time=t, index=i) diff --git a/thunder/executors/nvmathex.py b/thunder/executors/nvmathex.py new file mode 100644 index 0000000000..cea6c4e286 --- /dev/null +++ b/thunder/executors/nvmathex.py @@ -0,0 +1,41 @@ +from thunder import TensorProxy +from thunder.core.prims import PrimIDs +import nvmath +import thunder +import thunder.torch as ltorch +import torch + +nvmath_ex = thunder.extend.OperatorExecutor('nvmath', version='0.1.0') +thunder.extend.register_executor(nvmath_ex) + +def _nvmath_linalg_advanced_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return nvmath.linalg.advanced.matmul(a, b) + +def _nvmath_linalg_advanced_matmul_checker(a: TensorProxy, b: TensorProxy) -> bool: + if len(a.shape) < 2 or len(b.shape) < 2: + return False + if a.shape[-1] != b.shape[-2]: + return False + if a.device != b.device: + return False + if a.dtype != b.dtype: + return False + # Handle distribuited + return True + +nvmath_linalg_advanced_matmul = nvmath_ex.register_operator( + "nvmath_linalg_advanced_matmul", + like=ltorch.matmul, + fn=_nvmath_linalg_advanced_matmul_impl, +) +nvmath_ex.register_implementation( + ltorch.matmul, + nvmath_linalg_advanced_matmul, + checker=_nvmath_linalg_advanced_matmul_checker +) + +nvmath_ex.register_implementation( + PrimIDs.MATMUL, + nvmath_linalg_advanced_matmul, + checker=_nvmath_linalg_advanced_matmul_checker +) From 0b9c5fe870f0f00a4bf8df8f084f04ad40ccaca7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 24 Aug 2024 12:03:58 +0300 Subject: [PATCH 083/171] Removed old counter --- thunder/executors/nvfuserex_impl.py | 4 +--- thunder/executors/torch_compile.py | 3 +-- thunder/extend/__init__.py | 12 ------------ 3 files changed, 2 insertions(+), 17 deletions(-) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index e4967cc913..befb4e16bd 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -766,9 +766,7 @@ def _can_fuse_node(n: Node): bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) - # Counts how many fusions (per executor) have been constructed - # (Used to name fusions like nvFusion0, nvFusion1, ...) - fusion_counter: int = self.count_fusion_regions(trace, nvFuserExecutor) + fusion_counter = 0 for bsyms in bound_symbol_groups: # TODO The following allows generating single node fusions, which # may be suboptimal for real-world performance. diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 6247a0d272..9b4e86b980 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -155,8 +155,7 @@ def _can_fuse_node(n: Node): bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) fused_bsyms = [] - # Counts how many fusions (per executor) have been constructed - fusion_counter: int = self.count_fusion_regions(trace, TorchCompileExecutor) + fusion_counter = 0 for bsyms in bound_symbol_groups: if len(bsyms) == 1: bsym: BoundSymbol = bsyms[0] diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 4c5bed0a1f..c2fbae97fa 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -193,18 +193,6 @@ def _bind_postprocess(bsym: BoundSymbol) -> None: sym = Symbol(name=name, meta=_meta, is_fusion=True, _bind_postprocess=_bind_postprocess, executor=self) return sym.bind(*inputs, output=outputs) - # If a trace comes in with already placed fusion region we have to updated the initial counter (see derived class) - def count_fusion_regions(self, trace_in: TraceCtx, ex_type: type) -> int: - count = 0 - for bsym in trace_in.bound_symbols: - if not isinstance(bsym, BoundSymbol): - raise AssertionError(f"Expected a BoundSymbol, got: {type(bsym)}") - if type(bsym.sym.executor) is ex_type: - # if isinstance(bsym.sym.executor, FusionExecutor): - count += 1 - # ex.fuseion_pass regions are zero indexed - return max(0, count) - class OperatorExecutor(Executor): def __init__(self, name: Hashable, *, version: None | Any = None): From 511fe4bbe9f6fe49038b6e193dbf7b2745f7b12b Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 24 Aug 2024 12:22:32 +0300 Subject: [PATCH 084/171] Restored imports --- thunder/core/codeutils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index 2442c4ede0..d1c6d30406 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -5,12 +5,15 @@ from collections.abc import Mapping, Sequence, Iterable import inspect from inspect import Parameter +import string import functools from functools import partial import dis import linecache import dataclasses +import torch + import thunder.core.baseutils as baseutils from thunder.core.baseutils import ProxyInterface, check import thunder.core.dtypes as dtypes From 8cce6dfba437a7a5d8cf5cae82383155f96ea84f Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 24 Aug 2024 12:24:29 +0300 Subject: [PATCH 085/171] Restored line order --- thunder/core/prims.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/core/prims.py b/thunder/core/prims.py index 70d48c690e..a4c897c57b 100644 --- a/thunder/core/prims.py +++ b/thunder/core/prims.py @@ -3841,8 +3841,8 @@ def check_sequence(seq, seq_str_name, rank, *, min_val): utils.check(len(seq) == 1 or len(seq) == rank, lambda: f"len({seq_str_name}) should be either 1 or {rank}") - for i, e in enumerate(seq): # Check all elements are >= min_val + for i, e in enumerate(seq): utils.check( isinstance(e, (int, IntegerProxy)) and e >= min_val, lambda: f"all elements in {seq_str_name} should be integers at least {min_val}, " From cc42db2c056f7c024a776180494808ea3457127c Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 24 Aug 2024 16:34:40 +0300 Subject: [PATCH 086/171] Updated torch_compile_ex to synch with main --- thunder/executors/torch_compile.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/thunder/executors/torch_compile.py b/thunder/executors/torch_compile.py index 9b4e86b980..204bc2046d 100644 --- a/thunder/executors/torch_compile.py +++ b/thunder/executors/torch_compile.py @@ -7,11 +7,12 @@ from lightning_utilities import compare_version from thunder.core import prims, utils -from thunder.core.proxies import Proxy, unvariableify, Variable +from thunder.core.proxies import Proxy, TensorProxy, unvariableify, Variable from thunder.core.rematerialization import rematerialize from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import from_trace, TraceCtx, TraceProvenance from thunder.core.transform_common import dce +from thunder.core.pytree import tree_flatten from thunder.executors.passes import update_fusion_call_ctx from thunder.executors.utils import Region from thunder.extend import FusionExecutor, register_executor, ImplInfo @@ -155,7 +156,8 @@ def _can_fuse_node(n: Node): bound_symbol_groups = fuse_bound_symbols(trace, _should_fuse) fused_bsyms = [] - fusion_counter = 0 + # Counts how many fusions (per executor) have been constructed + fusion_counter: int = 0 for bsyms in bound_symbol_groups: if len(bsyms) == 1: bsym: BoundSymbol = bsyms[0] @@ -189,6 +191,16 @@ def _can_fuse_node(n: Node): from thunder.executors.torchex import ex as pytorch_ex +def cuda_device_checker(*args, **kwargs): + # We only want to compile if all the TensorProxy arguments are on the GPU + flat_args, _ = tree_flatten((args, kwargs)) + flat_tensorproxy_args = [x for x in flat_args if isinstance(x, TensorProxy)] + for arg in flat_tensorproxy_args: + if arg.device.type != "cuda": + return False + return True + + # NOTE: [torch_compile_cat_ex vs torch_compile_ex] # The former only relies on `torch.compile` for the operators where it shines the most and is meant to be used # together with the nvfuser executor. Its current goal is only to fuse RoPE but the set of ops fused will change as each @@ -199,14 +211,13 @@ def _can_fuse_node(n: Node): required_ops = { "torch.cat", prims.cat.id, - prims.pad.id, - prims.slice_prim.id, } torch_compile_cat_ex = TorchCompileExecutor(name="torchcompile_cat", required_ops=required_ops) register_executor(torch_compile_cat_ex) # TODO: Carefully enable more ops checking that they do improve performance supported_ops = { "torch.split", + "torch.sum", prims.add.id, prims.broadcast_in_dim.id, prims.cat.id, @@ -219,7 +230,9 @@ def _can_fuse_node(n: Node): prims.slice_prim.id, prims.transpose.id, } -torch_compile_cat_ex._implmap = {op: ImplInfo() for op in pytorch_ex.implmap if op in supported_ops} +torch_compile_cat_ex._implmap = { + op: ImplInfo(checker=cuda_device_checker) for op in pytorch_ex.implmap if op in supported_ops +} torch_compile_ex = TorchCompileExecutor(name="torchcompile") From c6bcdcca95acee7e66b9a0118546b7dd3adbeee5 Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 24 Aug 2024 23:01:20 +0300 Subject: [PATCH 087/171] Fixed print --- examples/dev/litGPT.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index f9e6612575..46136c0db2 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -22,7 +22,6 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in try: print('\n\nLayers:', test.layers) cfg = Config.from_name(model_name) - print(cfg) cfg.n_layer = test.layers torch.set_default_dtype(torch.bfloat16) with torch.device('cuda'): From 67ee05540f3107643bf48fbf1d7fc5f350a51424 Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 24 Aug 2024 23:27:33 +0300 Subject: [PATCH 088/171] Skipping single trace region candidate --- thunder/backend_optimizer/optimizer.py | 38 +++++++++++++++----------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 5f857eca19..3d94b60a47 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -148,7 +148,7 @@ class LogLevel(Enum): INFO = 1 -log_level: LogLevel = LogLevel.DEBUG +log_level: LogLevel = LogLevel.INFO def log(what: str, level: LogLevel = LogLevel.INFO): @@ -598,21 +598,27 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # Helpers candidate_best_time = BenchmarkResult() candidate_best_mem = BenchmarkResult() - # Search for best candidate, by default remat will be called to find the optimal choice - # TODO: enable requests for no remat becnhmarks - # TODO: we should consider also FusionExecutor that can execute this single bsym in this beam search - for i, candidate in enumerate(candidate_executors): - - from thunder.common import transform_for_execution - subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] - log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.DEBUG) - t, m, _ = benchmark_trace(subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) - log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.DEBUG) - # Update results - if t < candidate_best_time.runtime: - candidate_best_time = BenchmarkResult(time=t, index=i) - if m < candidate_best_mem.memory: - candidate_best_mem = BenchmarkResult(memory=m, index=i) + + # No choices + if len(candidate_executors) == 1: + candidate_best_time = BenchmarkResult(index=0) + candidate_best_mem = BenchmarkResult(index=0) + else: + # Search for best candidate, by default remat will be called to find the optimal choice + # TODO: enable requests for no remat becnhmarks + # TODO: we should consider also FusionExecutor that can execute this single bsym in this beam search + for i, candidate in enumerate(candidate_executors): + + from thunder.common import transform_for_execution + subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] + log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.DEBUG) + t, m, _ = benchmark_trace(subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) + log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.DEBUG) + # Update results + if t < candidate_best_time.runtime: + candidate_best_time = BenchmarkResult(time=t, index=i) + if m < candidate_best_mem.memory: + candidate_best_mem = BenchmarkResult(memory=m, index=i) if candidate_best_time.index == -1 or candidate_best_mem.index == -1: raise AssertionError( From 494fa73ca9242192838ac9a775e49a70bdbf463e Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 09:06:10 +0300 Subject: [PATCH 089/171] Debug for single trace regions --- thunder/backend_optimizer/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 3d94b60a47..70885b3278 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -611,9 +611,9 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): from thunder.common import transform_for_execution subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] - log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.DEBUG) + log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.INFO) t, m, _ = benchmark_trace(subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) - log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.DEBUG) + log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.INFO) # Update results if t < candidate_best_time.runtime: candidate_best_time = BenchmarkResult(time=t, index=i) From ed6013a616573bddc693e566e3e4eb95b38afdab Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 09:07:16 +0300 Subject: [PATCH 090/171] Print if no nvsight --- thunder/benchmarks/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 00c7ac844c..04b66e3d2f 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -160,10 +160,12 @@ def thunder_fw_bw_benchmark(fw_traces: list, bw_traces: list, fw_labels: list, b assert(len(fw_traces) == len(bw_traces) == len(fw_labels) == len(bw_labels)) for trc, label in zip(fw_traces, fw_labels): c, m, _ = benchmark_trace(trc, apply_del_last_used=False, snapshot=True, snapshot_name=label, iters=iters, nvsight=nvsight, nvsight_fn_name=label) - print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') + if not nvsight: + print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') i = 0 for trc, label in zip(bw_traces, bw_labels): c, m, _ = benchmark_trace(trc, apply_del_last_used=False, snapshot=True, snapshot_name=label, iters=iters, nvsight=nvsight, nvsight_fn_name=label, fw_trace=fw_traces[i]) - print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') + if not nvsight: + print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') i += 1 From c6270fbafa082ca5f51d64efcf1bbe46899f81ff Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 09:08:37 +0300 Subject: [PATCH 091/171] Updated litgpt runner --- examples/dev/litGPT.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 46136c0db2..a322d3fdf2 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -8,60 +8,60 @@ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn class Test: - def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: int) -> None: + def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: int = -1, model_name: str = 'Llama-3-8B') -> None: self.layers = layers self.autotune_type = autotune_type self.batch_size = batch_size self.seq_len = seq_len + self.model_name = model_name -layers = [Test(1, 'runtime', 1, 512)] - -model_name = 'open_llama_3b' +layers = [Test(1, 'runtime', 1), Test(1, 'runtime', 1, model_name='Llama-2-7b-hf')] for test in layers: try: print('\n\nLayers:', test.layers) - cfg = Config.from_name(model_name) + cfg = Config.from_name(test.model_name) cfg.n_layer = test.layers + if test.seq_len != -1: + cfg.block_size = test.seq_len torch.set_default_dtype(torch.bfloat16) with torch.device('cuda'): model = GPT(cfg) - x = torch.randint(1, model.config.vocab_size, (test.batch_size, 512)) + x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size)) + print(f'Input size: {x.size()}') - jmodel_def = thunder.jit(model, executors=['cudnn', 'nvfuser']) - jmodel_def_te = thunder.jit(model, executors=['cudnn', 'transformer_engine', 'nvfuser']) + jmodel_def = thunder.jit(model) + from thunder.executors.nvmathex import nvmath_ex jmodel_auto = thunder.jit( model, autotune_type=test.autotune_type, - executors=["nvfuser", "cudnn", "sdpa", "transformer_engine"], + executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", nvmath_ex], use_cudagraphs=False, ) print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('deviation def_te:', (jmodel_def_te(x) - model(x)).abs().max().item()) print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 fw_traces = [ thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_def_te)[-1], thunder.last_traces(jmodel_auto)[-1], ] bw_traces = [ thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_def_te)[-1], thunder.last_backward_traces(jmodel_auto)[-1], ] - fw_labels = ["fw_def", "fw_def_te", "fw_auto"] - bw_labels = ["bw_def", "bw_def_te", "bw_auto"] + fw_labels = ["fw_def", "fw_auto"] + bw_labels = ["bw_def", "bw_auto"] print(f'Results thunder benchmark ({iters} iters):') thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=True) thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) + print(test.model_name) print(f'\n\nResults torch fw bw benchmark ({iters} iters):') - callables = [jmodel_def, jmodel_def_te, jmodel_auto] - labels = ['def', 'def_te', 'auto'] - inputs = [x.clone().detach(), x.clone().detach(), x.clone().detach()] + callables = [jmodel_def, jmodel_auto] + labels = ['def', 'auto'] + inputs = [x, x] torch_fw_bw_benchmark(callables, labels, inputs, iters) print(f'\n\nResults torch total benchmark ({iters} iters):') torch_total_benchmark(callables, labels, inputs, iters) @@ -79,3 +79,5 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') except Exception as e: print(f'Test failed:\n{e}') + import traceback + traceback.print_exc() From 656e4e3fbf72ebf2ccaf5a691952972c03d3f9fc Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 10:25:00 +0300 Subject: [PATCH 092/171] Fixed cd assertion and print --- thunder/core/vjp_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 112b4d23ba..c33909d34b 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -183,9 +183,8 @@ def bw_fn(*args, **kwargs): return fw_fn, bw_fn, augmented_forward_trace, backward_trace cd = get_compile_data() - assert cd # No autotuning - if not cd.compile_options.get('autotune_type', None): + if not cd or not cd.compile_options.get('autotune_type', None): return _make_aug_forward_and_backward() # This search will be performed on the requested executors list @@ -213,7 +212,7 @@ def bw_fn(*args, **kwargs): backends = get_fw_bw_split_backends_options(bsym) if not backends: raise AssertionError( - f"No enabled backends found for {bsym.sym.name} but an executor for that symbol it is present in the executors list. Either remove that from the executors list or enable at least one backend for {bsym.sym.linear} inside 'get_fw_bw_split_backends_options'." + f"No enabled backends found for {bsym.sym.name} but an executor for that symbol it is present in the executors list. Either remove that from the executors list or enable at least one backend for {bsym.sym.name} inside 'get_fw_bw_split_backends_options'." ) cached_executors_list = list(cd.executors_list) From 9468802de7db094ddbc59bdfbfc26b3fc287b262 Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 11:38:57 +0300 Subject: [PATCH 093/171] Fixed cached update and restore missing args check --- thunder/core/vjp_utils.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index c33909d34b..87f4832b6e 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -51,7 +51,7 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable from thunder.common import _make_cache_key from thunder.core.transforms import _get_gradfn_and_executor, eval_trace - def _make_aug_forward_and_backward(return_traces = False) -> tuple[Callable, Callable] | tuple[Callable, Callable, TraceCtx, TraceCtx]: + def _make_aug_forward_and_backward(*, return_traces=False, update_cache=True): joint_forward_backward, executor = _get_gradfn_and_executor(bsym) utils.check( joint_forward_backward is not None, @@ -139,6 +139,17 @@ def find_backward_output(forward_input): # Remove put/get grad and backward symbols from augmented forward trace augmented_forward_trace = dce(augmented_forward_trace) + # Check that the number of outputs of the original forward function is the + # same as the number of primal outputs of the augmented forward trace + utils.check( + len(utils.sequencify(bsym.output)) == len(utils.sequencify(augmented_forward_trace.output[0])), + lambda: f"While generating forward and backward functions for {bsym.sym.name}, encountered an error.\n" + "The number of outputs of the original forward function must be the same as the number of primal outputs of the augmented forward trace.\n" + f"Number of outputs of the original forward function: {len(utils.sequencify(bsym.output))}\n" + f"Number of primal outputs of the augmented forward trace: {len(utils.sequencify(augmented_forward_trace.output[0]))}\n" + "Please check the forward function and the augmented forward trace to ensure that they have the same number of outputs.", + ) + # Check if any of the bound symbols in the backward trace are also in the # augmented forward trace # If so, remove them from the backward trace @@ -178,6 +189,9 @@ def fw_fn(*args, **kwargs): def bw_fn(*args, **kwargs): return eval_trace(backward_trace, *args, **kwargs) + if update_cache: + _cache[key] = fw_fn, bw_fn + if not return_traces: return fw_fn, bw_fn return fw_fn, bw_fn, augmented_forward_trace, backward_trace @@ -193,7 +207,6 @@ def bw_fn(*args, **kwargs): key = (bsym.sym, None, subkey := _make_cache_key(bsym.args, bsym.kwargs)) # Cached will be checked in the inner fn if not miss fw_fn, bw_fn = _make_aug_forward_and_backward() - _cache[key] = fw_fn, bw_fn return fw_fn, bw_fn # We have a backend else: @@ -233,7 +246,7 @@ def bw_fn(*args, **kwargs): requested_executors_list_for_bsym.remove(b) requested_executors_list_for_bsym.insert(0, b) cd.executors_list = requested_executors_list_for_bsym - fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(True) + fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(return_traces=True, update_cache=False) # What should be the optimal iter? # TODO: make benchmark info taken from an autotuner config fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=100, apply_del_last_used=False) From f3713b301ac2ca00f38d837e345d294f6c6ae53f Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 11:45:29 +0300 Subject: [PATCH 094/171] Removed visualizer --- thunder/backend_optimizer/optimizer.py | 7 ------- thunder/executors/torch_autograd.py | 15 --------------- 2 files changed, 22 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 70885b3278..041aee8a99 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -7,7 +7,6 @@ from thunder.core.trace import from_trace, TraceCtx from thunder.core.transforms import construct_trace from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors -from thunder.visualizer.visualizer_helper import Visualizer from typing import Hashable from thunder.backend_optimizer.utils import benchmark_trace @@ -163,7 +162,6 @@ def __init__( produce_log: bool = True, apply_bucketing_bw_trace: bool, log_file_name: str, - visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME, compile_data, ) -> None: @@ -178,7 +176,6 @@ def __init__( self.debug_msg: str = "" self.partial_costs: dict[TraceCtx, float] = {} - self.visualizer: Visualizer | None = visualizer self.log_file_name: str = log_file_name self.produce_log: bool = produce_log @@ -217,7 +214,6 @@ def __init__( produce_log: bool = True, apply_bucketing_bw_trace: bool, log_file_name: str, - visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME, compile_data, ) -> None: @@ -226,7 +222,6 @@ def __init__( produce_log=produce_log, apply_bucketing_bw_trace=apply_bucketing_bw_trace, log_file_name=log_file_name, - visualizer=visualizer, optimizer_type=optimizer_type, compile_data=compile_data ) @@ -1007,7 +1002,6 @@ def __init__( produce_log=True, apply_bucketing_bw_trace: bool, log_file_name="autotune_debug.log", - visualizer: Visualizer | None = None, optimizer_type: OptimizerType = OptimizerType.RUNTIME, optimizer_algorithm: OptimizationAlgorithm = OptimizationAlgorithm.BEST_FUSER, compile_data, @@ -1020,7 +1014,6 @@ def __init__( produce_log=produce_log, apply_bucketing_bw_trace=apply_bucketing_bw_trace, log_file_name=log_file_name, - visualizer=visualizer, optimizer_type=optimizer_type, compile_data=compile_data, ) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 477087d9e9..e4ac056fd8 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -144,7 +144,6 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat from thunder.distributed.transforms import FSDPCommBucketing from thunder.distributed.utils import sort_data_parallel_syncs, sort_waits, sort_communication_ops from thunder.executors.passes import del_last_used, transform_for_execution, autotune_transform_for_execution - from thunder.visualizer.visualizer_helper import Visualizer utils.check(compile_data is not None, lambda: "`compile_data` is required") # NOTE: This function is rather slow, so it's intended to be used @@ -206,7 +205,6 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat do_apply_bucketing_bw_trace: bool = getattr(compile_data.fn, "use_fsdp", False) # Now we can run the optimization passes on the forward trace - visualizer = Visualizer(produce_hidden=False) backend_optimizer_ctx: BackendOptimizer | None = ( None if autotune_type is None @@ -214,13 +212,11 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat priority_executors=compile_data.executors_list, apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, produce_log=True, - visualizer=visualizer, optimizer_type=autotune_type, compile_data=compile_data ) ) - visualizer.set_fw_initial_trace(fw_trace) # Get optimzied fw trace fw_extrace = ( transform_for_execution(fw_trace, executors_list=compile_data.executors_list) @@ -235,7 +231,6 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # Here fw_extrace is not None fw_traces.append(fw_extrace) - visualizer.set_fw_optimized_trace(fw_extrace) # If autotuning is activated, it will take care of the following 2 calls bw_trace = update_bw_from_forward_optimization(fw=fw_extrace, bw=bw_trace) @@ -243,20 +238,17 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat bw_trace = _fsdp_comm_bucketing.apply_bucketing_to_backward_trace(bw_trace) # Now we can run the optimization passes on the backward trace - visualizer.set_bw_initial_trace(bw_trace) if autotune_type is not None: fw_extrace, bw_extrace = autotune_transform_for_execution( optimizer_context=backend_optimizer_ctx, trace=bw_trace, trace_type=TraceType.BW ) fw_traces.append(fw_extrace) - visualizer.set_bw_optimized_trace(fw_extrace) else: bw_extrace = transform_for_execution( bw_trace, executors_list=compile_data.executors_list, ) bw_traces.append(bw_extrace) - visualizer.set_bw_optimized_trace(bw_extrace) if autotune_type is None: # TODO Restore request for no rematerialization @@ -339,11 +331,4 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat # We only want the forward function to be called with `te.fp8_autocast` manager. bw_extrace._include_te_fp8_autocast = False - # Let's include the last traces also after all the passes - visualizer.set_fw_final_trace(fw_extrace) - visualizer.set_bw_final_trace(bw_extrace) - - # TODO: implement new visualizer - # visualizer.produce() - return fw_extrace, bw_extrace From 1d0224ecce7934c89c521d207d64e289154b06a0 Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 11:54:18 +0300 Subject: [PATCH 095/171] Prev commit --- thunder/executors/passes.py | 2 -- thunder/visualizer/__init__.py | 0 thunder/visualizer/graphviz.py | 36 ------------------- thunder/visualizer/visualizer_helper.py | 48 ------------------------- 4 files changed, 86 deletions(-) delete mode 100644 thunder/visualizer/__init__.py delete mode 100644 thunder/visualizer/graphviz.py delete mode 100644 thunder/visualizer/visualizer_helper.py diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index b07178aae0..e84cb2136d 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -24,8 +24,6 @@ from thunder.extend import Executor, get_all_executors, get_always_executors, OperatorExecutor, FusionExecutor from thunder.backend_optimizer.optimizer import BackendOptimizer, OptimizerType, TraceCandidates, TraceType -from thunder.visualizer.graphviz import create_graphviz_pdf -from thunder.visualizer.visualizer_helper import Visualizer comment_symbols = {prims.PrimIDs.COMMENT, prims.PrimIDs.UNPACK_TRIVIAL} diff --git a/thunder/visualizer/__init__.py b/thunder/visualizer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/thunder/visualizer/graphviz.py b/thunder/visualizer/graphviz.py deleted file mode 100644 index db3a98f17f..0000000000 --- a/thunder/visualizer/graphviz.py +++ /dev/null @@ -1,36 +0,0 @@ -import graphviz -from thunder.executors.data_dependent_partition import Node, Graph - -def to_graphviz_dag(g: Graph) -> graphviz.Digraph: - dot = graphviz.Digraph() - visit_stack = list(g.roots) - # Add root nodes - r: Node - for r in g.roots: - dot.node(f'{r.ID}({r.group_bsyms[0].sym.name})') - - visited = set() - cur: Node - while visit_stack: - cur = visit_stack.pop(0) - if cur in visited: - continue - - cur_node_str = f'{cur.ID}({cur.group_bsyms[0].sym.name})' - dot.node(cur_node_str) - - # Connect with parent - for p in cur.parents: - id = p.ID - op = p.group_bsyms[0].sym.name - parent_str = f'{id}({op})' - dot.edge(parent_str, cur_node_str) - - visited.add(cur) - visit_stack.extend(cur.children) - return dot - - -def create_graphviz_pdf(g: Graph, name='graph', directory='./'): - dot = to_graphviz_dag(g) - dot.render(name, view=False, cleanup=True, directory=directory) diff --git a/thunder/visualizer/visualizer_helper.py b/thunder/visualizer/visualizer_helper.py deleted file mode 100644 index 3abdf26121..0000000000 --- a/thunder/visualizer/visualizer_helper.py +++ /dev/null @@ -1,48 +0,0 @@ -from thunder.core.trace import TraceCtx -from thunder.core.transform_common import dce -from thunder.executors.data_dependent_partition import Graph -from thunder.visualizer.graphviz import create_graphviz_pdf - -class Visualizer(): - def __init__(self, produce_hidden = False, traces_directory='traces/') -> None: - self.produce_hidden = produce_hidden - self.traces: dict[str, TraceCtx] = {} - self.hidden_traces: dict[str, TraceCtx] = {} - self.traces_directory = traces_directory - - def set_fw_initial_trace(self, trace: TraceCtx) -> None: - self.traces['fw_initial'] = dce(trace) - - def set_fw_optimized_trace(self, trace: TraceCtx) -> None: - self.traces['fw_optimized'] = dce(trace) - - def set_fw_final_trace(self, trace: TraceCtx) -> None: - self.traces['fw_final'] = dce(trace) - - def set_bw_initial_trace(self, trace: TraceCtx) -> None: - self.traces['bw_initial'] = dce(trace) - - def set_bw_optimized_trace(self, trace: TraceCtx) -> None: - self.traces['bw_optimized'] = dce(trace) - - def set_bw_final_trace(self, trace: TraceCtx) -> None: - self.traces['bw_final'] = dce(trace) - - def set_hidden_trace(self, name: str, trace: TraceCtx) -> None: - self.traces[name] = dce(trace) - - def produce(self): - for k, v in self.traces.items(): - try: - g = Graph(v) - create_graphviz_pdf(g, k, directory=self.traces_directory) - except Exception as e: - print(f"Visualizer failed to produce {k}: {e}") - - if self.produce_hidden: - for k, v in self.hidden_traces.items(): - try: - g = Graph(v) - create_graphviz_pdf(g, k, directory=self.traces_directory) - except Exception as e: - print(f"Visualizer failed to produce hidden {k}: {e}") From f93c950c4cc6142d66a21ad7887467199014b428 Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 15:51:13 +0300 Subject: [PATCH 096/171] Updated test runner --- examples/dev/litGPT.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index a322d3fdf2..7b3aaedaec 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -3,19 +3,21 @@ from thunder.tests.litgpt_model import Config import thunder import torch +from thunder.executors.nvmathex import nvmath_ex torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn class Test: - def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: int = -1, model_name: str = 'Llama-3-8B') -> None: + def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: int = -1, model_name: str = 'Llama-3-8B', executors = None) -> None: self.layers = layers self.autotune_type = autotune_type self.batch_size = batch_size self.seq_len = seq_len self.model_name = model_name + self.executors = executors -layers = [Test(1, 'runtime', 1), Test(1, 'runtime', 1, model_name='Llama-2-7b-hf')] +layers = [Test(1, 'runtime', 1, executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", ]), Test(1, 'runtime', 1, model_name='Llama-2-7b-hf', executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", nvmath_ex])] for test in layers: try: @@ -31,11 +33,10 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in print(f'Input size: {x.size()}') jmodel_def = thunder.jit(model) - from thunder.executors.nvmathex import nvmath_ex jmodel_auto = thunder.jit( model, autotune_type=test.autotune_type, - executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", nvmath_ex], + executors=test.executors, use_cudagraphs=False, ) From 5d7bd9d60537fc30d9cd1f8ba79a90707b76d9ed Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 16:14:54 +0300 Subject: [PATCH 097/171] Fixed cache --- thunder/core/vjp_utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 87f4832b6e..69881df51e 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -3,6 +3,7 @@ from functools import wraps from inspect import Parameter, Signature from itertools import chain +from os import execl from thunder.core import prims, utils from thunder.core.compile_data import get_compile_data @@ -51,16 +52,20 @@ def make_aug_forward_and_backward(bsym: BoundSymbol) -> tuple[Callable, Callable from thunder.common import _make_cache_key from thunder.core.transforms import _get_gradfn_and_executor, eval_trace - def _make_aug_forward_and_backward(*, return_traces=False, update_cache=True): + def _make_aug_forward_and_backward(*, return_traces=False, update_cache=True) -> tuple: joint_forward_backward, executor = _get_gradfn_and_executor(bsym) utils.check( joint_forward_backward is not None, lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", ) key = (bsym.sym, executor, subkey := _make_cache_key(bsym.args, bsym.kwargs)) - cached_result = _cache.get(key, None) if subkey is not None else None - if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): - return cached_result + print(key) + # If we update the cache we are not using the autotuner hence cache values for the key entry generated above is valid. + # If autotuner is used, each bsym has an unique key id hence this cache entry is not valid anymore. + if update_cache: + cached_result = _cache.get(key, None) if subkey is not None else None + if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): + return cached_result joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) consumers = utils.consumers(joint_trace) From 93175616bbf6ebed89261b75f4dcd417baa925a2 Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 17:02:04 +0300 Subject: [PATCH 098/171] Removed print --- thunder/core/vjp_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 69881df51e..055dc5103d 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -59,7 +59,6 @@ def _make_aug_forward_and_backward(*, return_traces=False, update_cache=True) -> lambda: f"Cannot generate forward and backward functions for {bsym.sym.name}", ) key = (bsym.sym, executor, subkey := _make_cache_key(bsym.args, bsym.kwargs)) - print(key) # If we update the cache we are not using the autotuner hence cache values for the key entry generated above is valid. # If autotuner is used, each bsym has an unique key id hence this cache entry is not valid anymore. if update_cache: From 3aaf44d0d60ec1d920c7da9542d9472cadb5fe33 Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 17:03:06 +0300 Subject: [PATCH 099/171] Disabled debug --- thunder/backend_optimizer/optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 041aee8a99..2059a451eb 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -606,9 +606,9 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): from thunder.common import transform_for_execution subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] - log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.INFO) + log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.DEBUG) t, m, _ = benchmark_trace(subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) - log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.INFO) + log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.DEBUG) # Update results if t < candidate_best_time.runtime: candidate_best_time = BenchmarkResult(time=t, index=i) From 5b6deb077dcdf6a9c1c5aa4761ef9c976a874edc Mon Sep 17 00:00:00 2001 From: matteochen Date: Mon, 26 Aug 2024 17:12:21 +0300 Subject: [PATCH 100/171] Updated litgpt --- examples/dev/litGPT.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 7b3aaedaec..363f548ecb 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -17,16 +17,41 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in self.model_name = model_name self.executors = executors -layers = [Test(1, 'runtime', 1, executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", ]), Test(1, 'runtime', 1, model_name='Llama-2-7b-hf', executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", nvmath_ex])] +layers = [ + Test( + 1, + "runtime", + 1, + executors=[ + "cudnn", + "sdpa", + "fa3", + "nvfuser", + "torchcompile", + ], + ), + Test( + 1, + "runtime", + 1, + executors=["cudnn", "sdpa", "nvfuser", "torchcompile",], + ), + Test( + 1, + "runtime", + 1, + executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", nvmath_ex], + ), +] for test in layers: try: - print('\n\nLayers:', test.layers) cfg = Config.from_name(test.model_name) cfg.n_layer = test.layers if test.seq_len != -1: cfg.block_size = test.seq_len torch.set_default_dtype(torch.bfloat16) + print(cfg) with torch.device('cuda'): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size)) From 550b6395134f2c62b67f42f4b09a1ea6b34dac9f Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 00:52:05 +0300 Subject: [PATCH 101/171] Unit tests and / minors changes for compilation to gain flexibility / executors in compile stats --- examples/dev/sdpa.py | 3 - thunder/__init__.py | 16 ++ thunder/backend_optimizer/optimizer.py | 9 +- thunder/backend_optimizer/utils.py | 49 +++-- thunder/common.py | 2 + thunder/core/transform_common.py | 106 +++++------ thunder/executors/passes.py | 24 ++- thunder/executors/sdpaex.py | 2 - thunder/executors/torch_autograd.py | 1 + thunder/tests/test_autotuner.py | 237 +++++++++++++++++++++++++ 10 files changed, 354 insertions(+), 95 deletions(-) create mode 100644 thunder/tests/test_autotuner.py diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index 75fd67cd39..abe57d8407 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -53,6 +53,3 @@ def forward(self, query, key, value): print('###############################################################################') print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - print('\nTorch benchmark:') - bench(jmodel_def, 'def', iters) - bench(jmodel_auto, 'auto', iters) diff --git a/thunder/__init__.py b/thunder/__init__.py index 23ee3135fb..ecf729ac2c 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -494,6 +494,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_traces = comp_traces cs.last_interpreted_instructions = None cs.last_interpreter_log = None + cs.last_executors = cd.executors_list cs.last_prologue_traces = pro_traces cs.last_prologue = pro cs.last_prologue_transformation_start = 0 @@ -533,6 +534,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_traces = comp_traces cs.last_interpreted_instructions = None cs.last_interpreter_log = None + cs.last_executors = cd.executors_list cs.last_prologue_traces = pro_traces cs.last_prologue = pro @@ -650,6 +652,7 @@ def get_computation_and_inputs(*args, **kwargs): cs.last_prologue_traces = prologue_traces cs.last_prologue = pro cs.last_traces = computation_traces + cs.last_executors = cd.executors_list backward_traces = [] cs.last_backward_traces = backward_traces cs.last_interpreter_log = last_interpreter_log @@ -937,6 +940,19 @@ def last_prologue_traces(fn) -> TraceCtx: return cs.last_prologue_traces +def executors_applied(fn) -> Sequence[Executor]: + """Obtains the list of executors that have been applied to the computational trace. + If the backward trace is not None, the list will include also executors used in the backward trace. + + """ + cs = compile_stats(fn) + if cs is None: + raise TypeError(f"{fn} doesn't seem to be a thunder compiled function.") + if cs.last_executors is None: + raise TypeError(f"{fn} doesn't seem to have been called yet.") + return cs.last_executors + + def cache_option(fn) -> CACHE_OPTIONS: """Returns the cache options set when JITting the function.""" cd = compile_data(fn) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 2059a451eb..3cc4e38d70 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -479,7 +479,7 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo placed_trace = assign_executors( in_trace=trc, - executor_list=executor_configuration, + executors_list=executor_configuration, always_executors=self.always_executors, empty_str=self.empty_executor_hashable_placeholder, ) @@ -698,6 +698,7 @@ def measure_and_update_result(): match_bsym_output( group[k], [dict_time_strat, dict_mem_strat], + # In order to benchmark the fusion placecement, we can use any executor for the excluded bsym from the fusion region get_first_available_operator_executor( bsym=group[k], executors=self.executors, @@ -800,7 +801,7 @@ def measure_and_update_result(): ): trc = assign_executors( in_trace=trace, - executor_list=executors, + executors_list=executors, always_executors=self.always_executors, empty_str=self.empty_executor_hashable_placeholder, ) @@ -916,7 +917,7 @@ def _optimize(): ): trc = assign_executors( in_trace=self.trace, - executor_list=placement_ctx.placement, + executors_list=placement_ctx.placement, always_executors=self.always_executors, empty_str=self.empty_executor_hashable_placeholder, compile_data=self.compile_data, @@ -930,7 +931,7 @@ def _optimize(): ): trc = assign_executors( in_trace=self.trace, - executor_list=placement_ctx.placement, + executors_list=placement_ctx.placement, always_executors=self.always_executors, empty_str=self.empty_executor_hashable_placeholder, compile_data=self.compile_data, diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 6d104d9744..1100b84396 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -18,13 +18,13 @@ def rec(s) -> str: name = "[" for e in s: if e is None: - name += "None" + name += "None#" elif hasattr(e, "name"): - name += e.name - elif isinstance(e, Sequence): + name += e.name + '#' + elif isinstance(e, Sequence) and not isinstance(e, str): name += rec(e) elif isinstance(e, int): - name += 'int' + str(e) + name += 'int' + str(e) + '#' else: raise AssertionError(f"Unsupported type = {type(e)}") name += ']' @@ -49,6 +49,15 @@ def get_first_available_operator_executor( return ex return Executor(name=empty_hash) +def flatten_sequence(sequence: Sequence) -> list: + res = [] + for e in sequence: + if isinstance(e, Sequence): + res.extend(flatten_sequence(e)) + # Skip Nones as they are not useful + elif e is not None: + res.append(e) + return res def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[Proxy]: def is_in_sequence(seq: Sequence[Any], t: Proxy): @@ -57,22 +66,6 @@ def is_in_sequence(seq: Sequence[Any], t: Proxy): return True return False - def is_possible_out(name: str): - if not name.startswith("t"): - return False - num = name[1:] - return num.isdigit() - - def flatten_sequence(sequence: Sequence) -> list: - res = [] - for e in sequence: - if isinstance(e, Sequence): - res.extend(flatten_sequence(e)) - # Skip Nones as they are not useful - elif e is not None: - res.append(e) - return res - def unpack_output(out) -> Sequence[Proxy]: if issubclass(type(out), Proxy): return [out] @@ -107,7 +100,7 @@ def unpack_output(out) -> Sequence[Proxy]: def assign_executors( *, in_trace: TraceCtx, - executor_list: list[Executor | FusionExecutor | OperatorExecutor] + executors_list: list[Executor | FusionExecutor | OperatorExecutor] | tuple[Executor | FusionExecutor | OperatorExecutor, ...], always_executors: list[Executor] | tuple[Executor, ...], empty_str: str | Hashable, @@ -197,24 +190,24 @@ def visit_helper(bsym: BoundSymbol, ex: Executor) -> None | bool: def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return transforms.VISIT_TYPE.NO_OP if visit_helper(bsym, ex) is None else transforms.VISIT_TYPE.REPLACE - if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError("len(executor_list) != len(in_trace.bound_symbols)") + if len(executors_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executors_list) != len(in_trace.bound_symbols)") cached_subsymbols: dict[str, Sequence[BoundSymbol]] = {} executor_mapping: dict[str, Executor] = {} unique_fusion_executors = set() # Input should have equal length - if len(executor_list) != len(in_trace.bound_symbols): - raise AssertionError("len(executor_list) != len(extrace.bound_symbols)") + if len(executors_list) != len(in_trace.bound_symbols): + raise AssertionError("len(executors_list) != len(extrace.bound_symbols)") - for b, e in zip(in_trace.bound_symbols, executor_list): + for b, e in zip(in_trace.bound_symbols, executors_list): if isinstance(e, FusionExecutor): unique_fusion_executors.add(e) if isinstance(b.output, TensorProxy): executor_mapping[b.output.name] = e - extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executor_list)) + extrace = transforms.visitor_transform_paired(in_trace, visit, zip(in_trace.bound_symbols, executors_list)) # Restores original variables bound_symbols: list[BoundSymbol] = [] @@ -692,7 +685,7 @@ def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | l # Transformer engine Context object # # This instruction will populate the args with a dummy context which is not correct in theory. - # For the benchmark purpose (where this fn is currently used) this error will not impact on the runtime correctness as at the end we + # For the benchmark purpose (where this fn is currently used) this error will not impact on the runtime correctness as at the end we # will use the cached runtime contexts from the forward pass. # We need this only to generate a context for the static inputs (which are discarded afterwards). # diff --git a/thunder/common.py b/thunder/common.py index 95414015bd..3a40c6a56a 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -70,6 +70,7 @@ class CompileStats: last_prologue_traces (Sequence[TraceCtx]): last_interpreted_instructions (Generator[dist.Instruction, None, None] | None): last_interpreter_log (list[InterpreterLogItem] | None): + last_executors (Sequence[Executor] | None): last_backward_traces (Sequence[TraceCtx]): last_trace_host_start (int): last_trace_host_stop (int): @@ -103,6 +104,7 @@ def __init__(self): self.last_prologue_traces = None self.last_interpreted_instructions: Generator[dis.Instruction, None, None] | None = None self.last_interpreter_log: list[InterpreterLogItem] | None = None + self.last_executors: Sequence[Executor] | None = None # torch.autograd.Function specific data self.last_backward_traces = None diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 1750c7cfe7..88da7d9907 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -9,6 +9,7 @@ from functools import partial import thunder +from thunder.core.compile_data import get_compile_data import thunder.core.prims as prims from thunder.core.baseutils import BoundSymbolInterface from thunder.core.proxies import Proxy, variableify, Variable, TensorProxy, unvariableify @@ -98,65 +99,70 @@ def check(inp, log_str): def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: start_time_ns = time.perf_counter_ns() - producer_map: ProxyDict = producers(trace) + cd = get_compile_data() + disabled = not(not cd or (cd and not cd.compile_options.get('disable_dce', None))) + if not disabled: + producer_map: ProxyDict = producers(trace) - flat_trace_outputs, _ = tree_flatten(trace.output) - if needed_proxies is None: - needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) - else: - needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) - dced = [] - - bsym: BoundSymbol - for bsym in reversed(trace.bound_symbols): - # Preserves symbols that should never be collected - if has_tags(bsym, {prims.OpTags.DONT_DCE}): - needed = True + flat_trace_outputs, _ = tree_flatten(trace.output) + if needed_proxies is None: + needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) else: - needed = False + needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) + dced = [] - # NOTE This block is run even if we know we're preserving the operation, because it - # may mark some of the operation's outputs as unused - some_unused = False - for out in bsym.flat_proxy_outs: - if variableify(out) in needed_proxies and producer_map[out] == bsym: + bsym: BoundSymbol + for bsym in reversed(trace.bound_symbols): + # Preserves symbols that should never be collected + if has_tags(bsym, {prims.OpTags.DONT_DCE}): needed = True else: - some_unused = True - - if needed: - nbsym: BoundSymbol = bsym - - # Replaces unused Proxy outputs with None - if some_unused: - - def _helper(x): - if isinstance(x, Proxy) and (variableify(x) not in needed_proxies or producer_map[x] != bsym): - return None - return x - - nbsym_output = tree_map(_helper, bsym.output) - nbsym = bsym.from_bsym(output=nbsym_output) - - # Eliminates no-op subsymbols - # NOTE In general editing subsymbols doesn't do anything, but no-op subsymbols are a pain - # for transforms to deal with. Transforms typically look for a "flattened" version of an - # operator for which they can apply their rules, and no-op subsymbols have no - # flattening, requiring each transform handle them explicitly or DCE them themselves - # while flattening. - _remove_noop_subsymbols(nbsym) - - dced.append(nbsym) - for x in nbsym.flat_proxy_args: - needed_proxies.add(variableify(x)) - - dcetrace = from_trace(trace) - dcetrace.bound_symbols = list(reversed(dced)) + needed = False + + # NOTE This block is run even if we know we're preserving the operation, because it + # may mark some of the operation's outputs as unused + some_unused = False + for out in bsym.flat_proxy_outs: + if variableify(out) in needed_proxies and producer_map[out] == bsym: + needed = True + else: + some_unused = True + + if needed: + nbsym: BoundSymbol = bsym + + # Replaces unused Proxy outputs with None + if some_unused: + + def _helper(x): + if isinstance(x, Proxy) and (variableify(x) not in needed_proxies or producer_map[x] != bsym): + return None + return x + + nbsym_output = tree_map(_helper, bsym.output) + nbsym = bsym.from_bsym(output=nbsym_output) + + # Eliminates no-op subsymbols + # NOTE In general editing subsymbols doesn't do anything, but no-op subsymbols are a pain + # for transforms to deal with. Transforms typically look for a "flattened" version of an + # operator for which they can apply their rules, and no-op subsymbols have no + # flattening, requiring each transform handle them explicitly or DCE them themselves + # while flattening. + _remove_noop_subsymbols(nbsym) + + dced.append(nbsym) + for x in nbsym.flat_proxy_args: + needed_proxies.add(variableify(x)) + + dcetrace = from_trace(trace) + dcetrace.bound_symbols = list(reversed(dced)) + else: + dcetrace = trace end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination (took {elapsed_time_millis} milliseconds)")) + dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)")) return dcetrace diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index e84cb2136d..dd70fada74 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -8,6 +8,7 @@ from functools import partial import time +from thunder.core.compile_data import get_compile_data from thunder.core.trace import TraceCtx, from_trace, TraceProvenance, VariableInterface, reset_tracectx, set_tracectx from thunder.core.codeutils import SigInfo import thunder.core.dtypes as dtypes @@ -352,21 +353,28 @@ def del_last_used(trace: TraceCtx, *, clear_mutable_collections=False) -> TraceC Returns: list: transformed trace """ - start_time_ns = time.perf_counter_ns() - del_trace = from_trace(trace) + # If dce is disabled, we have to disable this pass also + cd = get_compile_data() + disabled = not(not cd or (cd and not cd.compile_options.get('disable_dce', None))) - outs = cutils.sequencify(trace.output) - flat_outs, _ = tree_flatten(outs) + start_time_ns = time.perf_counter_ns() - del_trace.bound_symbols = _del_last_used( - trace.bound_symbols, flat_outs, clear_mutable_collections=clear_mutable_collections - ) + if not disabled: + del_trace = from_trace(trace) + outs = cutils.sequencify(trace.output) + flat_outs, _ = tree_flatten(outs) + + del_trace.bound_symbols = _del_last_used( + trace.bound_symbols, flat_outs, clear_mutable_collections=clear_mutable_collections + ) + else: + del_trace = trace end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - del_trace.set_provenance(TraceProvenance(f"Delete Last Used (took {elapsed_time_millis} milliseconds)")) + del_trace.set_provenance(TraceProvenance(f"Delete Last Used{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)")) return del_trace diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index 76f0e733d8..cb569f3e28 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -505,13 +505,11 @@ def _scaled_dot_product_attention_fused( tensor_args = (query, key, value) scalar_args = (dropout_p, is_causal) if backend == SpdaBackend.FLASH_ATTENTION: - print('FLASH ATT') # Use flash attention kernel (primal, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _) = sdpfa_gradfwd( *tensor_args, *scalar_args, scale=scale ) elif backend == SpdaBackend.MEMORY_EFFICIENT: - print('MEM EFF') # Use memory efficient kernel, which supports fp32 and attention mask arguments (primal, logsumexp, philox_seed, philox_offset) = sdpea_gradfwd( *tensor_args, attn_mask, *scalar_args, scale=scale diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index e4ac056fd8..15be9079a5 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -325,6 +325,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if compile_stats is not None: compile_stats.last_traces += fw_traces compile_stats.last_backward_traces += bw_traces + compile_stats.last_executors = compile_data.executors_list # Enable wrapping with `te.fp8_autocast`. fw_extrace._include_te_fp8_autocast = True diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py new file mode 100644 index 0000000000..f8937ed05d --- /dev/null +++ b/thunder/tests/test_autotuner.py @@ -0,0 +1,237 @@ +from typing import Sequence +import thunder.backend_optimizer.utils as aut_utils +import pytest +import torch +import thunder +from thunder.core.prims import PrimIDs +from thunder.core.symbol import BoundSymbol +from thunder.core.trace import TraceCtx +from thunder.extend import Executor, get_always_executors +from thunder.executors.torchex import ex as torchex +from thunder.executors.torch_compile import torch_compile_ex +from thunder.executors.nvfuserex import nvfuserex +from thunder.tests.framework import requiresCUDA, run_snippet + +class DummyProxy(): + def __init__(self, name) -> None: + self.name = name + +@pytest.mark.parametrize("data,expected", + [ + ([DummyProxy('a'), DummyProxy('b')], '[a#b#]'), + ([DummyProxy('a'), DummyProxy('b'), 90], '[a#b#int90#]'), + ([DummyProxy('a'), DummyProxy('b'), 90, None], '[a#b#int90#None#]'), + ([DummyProxy('a'), DummyProxy('b'), 90, [DummyProxy('c'), [DummyProxy('d')]]], '[a#b#int90#[c#[d#]]]') + ] +) +def test_sequence_hash(data, expected): + assert(aut_utils.sequence_hash(data) == expected) + +@pytest.mark.parametrize("data,expected", + [ + ([DummyProxy('a'), "b"], '[a#b#]'), + ] +) +def test_sequence_hash_bad_input(data, expected): + with pytest.raises(AssertionError): + assert aut_utils.sequence_hash(data) == expected + + +@pytest.mark.parametrize("data,expected_sum,expected_others", + [ + ([nvfuserex, torch_compile_ex], Executor(name='empty'), Executor(name='empty')), + ([nvfuserex, torchex], torchex, Executor(name='empty')), + ] +) +def test_first_available_operator_executor(data,expected_sum,expected_others): + def fn(a: torch.Tensor, b: torch.Tensor): + return a + b + + a = torch.randn(1,1) + b = torch.randn(1,1) + jitted = thunder.jit(fn) + jitted(a, b) + trace = thunder.last_traces(jitted)[-1] + for bsym in trace.bound_symbols: + if bsym.sym.id == PrimIDs.ADD: + assert aut_utils.get_first_available_operator_executor(bsym = bsym, executors = data, empty_hash = 'empty') == expected_sum + else: + assert aut_utils.get_first_available_operator_executor(bsym = bsym, executors = data, empty_hash = 'empty') == expected_others + +@pytest.mark.parametrize("test,expected", + [ + ([1, 2, 3], [1, 2, 3]), + ([1, 2, [3, 4]], [1, 2, 3, 4]), + ([1, 2, [3, 4, [None]]], [1, 2, 3, 4]), + ] +) +def test_flatten_sequence(test, expected): + assert aut_utils.flatten_sequence(test) == expected + + +def test_get_not_used_intermediate_outputs(): + # Flat outputs + def fn(a: torch.Tensor, b: torch.Tensor): + t1 = a - b + t2 = a * b + t3 = a / b + return (a + b) + t2 + + a = torch.randn(1,1) + b = torch.randn(1,1) + jitted = thunder.jit(fn, disable_dce=True) + jitted(a, b) + trace = thunder.last_traces(jitted)[-1] + + not_used = aut_utils.get_not_used_intermediate_outsputs(trace) + # We have not used t1, t3 in trace + not_used_labels = ['t1', 't3'] + assert len(not_used) == 2 + for t in not_used: + assert t.name in not_used_labels + not_used_labels.remove(t.name) + +def _assign_executors_fn(a: torch.Tensor): + t0 = a * 2 + t1 = a * a + t3 = t0 + t1 + return t3 + +@pytest.mark.parametrize( + "fn, args, executors", + [ + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torchex, torchex, torchex, Executor("empty")], + ), + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torch_compile_ex, torch_compile_ex, torch_compile_ex, Executor("empty")], + ), + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torch_compile_ex, torch_compile_ex, torchex, Executor("empty")], + ), + ( + _assign_executors_fn, + torch.randn(1, 1), + [Executor("empty"), torchex, torch_compile_ex, torch_compile_ex, Executor("empty")], + ) + ], +) +def test_assign_executors(fn, args, executors): + trace: TraceCtx = thunder.trace(inline_trace=True)(fn, args) + placed: TraceCtx = aut_utils.assign_executors( + in_trace=trace, executors_list=executors, always_executors=get_always_executors(), empty_str="empty" + ) + + def _id(bsym: BoundSymbol): + res = bsym.sym.name + if isinstance(bsym.output, Sequence): + res += '#' + aut_utils.sequence_hash(bsym.output) + else: + res += '#' + bsym.output.name + + return res + + # Unapacks and return symbols are filtered out + executor_map = { + _id(b): e if e.name != "empty" else None + for b, e in zip(trace.bound_symbols, executors) + if b.output is not None and b.sym.id != PrimIDs.RETURN + and b.args is not None + } + + for b in placed.bound_symbols: + # print(b) + if b.sym.is_fusion: + # Search in every subsymbol + for sub in b.subsymbols: + identif = _id(sub) + assert b.sym.executor == executor_map[identif] + elif b.sym.id != PrimIDs.RETURN and b.args: + identif = _id(b) + assert b.sym.executor == executor_map[identif] + +class Linear(torch.nn.Module): + def __init__(self, a, b) -> None: + super().__init__() + self.linear = torch.nn.Linear(a, b) + + def forward(self, x): + return self.linear(x) + +class Matmul(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + return x @ x + +@pytest.mark.parametrize( + "model, x, op, expected", + [ + ( + Linear(8,8), + torch.randn(8,8), + 'linear', + True + ), + ( + Linear(8,8), + torch.randn(8,8), + 'add', + False + ), + ( + Matmul(), + torch.randn(8,8), + 'matmul', + True + ), + ], +) +def test_operation_in_trace(model, x, op, expected): + jitted = thunder.jit(model) + jitted(x) + # jitted(args if not isinstance(args, Sequence) else *args) + trace = thunder.last_traces(jitted)[-1] + + assert aut_utils.operation_in_trace(trace=trace, op=op) == expected + +class Sdpa(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, q, k, v): + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + +@pytest.mark.parametrize("device,", ["cuda"]) +@requiresCUDA +def test_sdpa(device: str): + batch = 10 + seq_len = 128 + num_heads = 4 + dim_per_head = 32 + + query = torch.randn([batch, seq_len, num_heads, dim_per_head], dtype=torch.float16, device=device, requires_grad=True) + key = torch.randn([batch, seq_len, num_heads, dim_per_head], dtype=torch.float16, device=device, requires_grad=True) + value = torch.randn([batch, seq_len, num_heads, dim_per_head], dtype=torch.float16, device=device, requires_grad=True) + + model = Sdpa() + executors = ['cudnn', 'sdpa'] + jitted = thunder.jit(model, autotune_type='runtime', executors=executors) + jitted(query, key, value) + + exs: Sequence[Executor] = thunder.executors_applied(jitted) + + sdpa_executors_occurences = 0 + for ex in exs: + if ex.name in executors: + sdpa_executors_occurences += 1 + + assert sdpa_executors_occurences == 1 + From dfc70b46bffb64a7e852ad59d11fe9a3a262531e Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 11:17:08 +0300 Subject: [PATCH 102/171] Fix appended label --- thunder/backend_optimizer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 1100b84396..53f6277030 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -738,6 +738,6 @@ def reorder_executors_list(executors: Sequence): elif (ex == nvfuser_ex or ex == torch_compile_ex): found = True if not found: - reordered.append(nvfuser_ex) + reordered.insert(0, nvfuser_ex.name if are_inputs_names else nvfuser_ex) return reordered From 3b3009b6e16c2cdc59de23e3aa72bdcabdfa16fb Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 11:18:01 +0300 Subject: [PATCH 103/171] Updated comments and removed import --- thunder/backend_optimizer/utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 53f6277030..f58aa18dd9 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -3,7 +3,7 @@ from thunder.core.compile_data import get_compile_data from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs -from thunder.core.proxies import AnyProxy, CollectionProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify +from thunder.core.proxies import AnyProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import TraceCtx, get_tracectx, reset_tracectx, set_tracectx from thunder.extend import Executor, FusionExecutor, OperatorExecutor @@ -636,8 +636,7 @@ def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: if torch_dtype is None: raise AssertionError(f'Unrecognized thunder dtype: {dtype}') if is_float_dtype(dtype): - # Use TE Float8 if TE is enabled, it has float32 ad torch dtype - # NOTE: if we have a standalone torch.float8 inside the args and it is not the TE Float8 it won't be parsed correctly for now + # Use TE Float8 if TE is enabled, it has float32 torch dtype te_used = kwargs.get('te_used', False) if te_used: tensor: torch.Tensor = torch.randn( @@ -665,7 +664,6 @@ def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: return tensor -# TODO (matteochen): use more appropriate mock int and float def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | list: from thunder.executors.transformer_engineex import Context as C res = [] From 4ef329b28f913c82cd50134d814298870c16cd7a Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 11:37:14 +0300 Subject: [PATCH 104/171] New tests and linter --- thunder/backend_optimizer/optimizer.py | 91 +++++--- thunder/tests/test_autotuner.py | 292 +++++++++++++++++++------ 2 files changed, 280 insertions(+), 103 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 3cc4e38d70..fad3fe333c 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -10,17 +10,18 @@ from typing import Hashable from thunder.backend_optimizer.utils import benchmark_trace -# Defining a wrapper fn as the imports will crash in the global scope + def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None) -> list | dict: from thunder.executors.sdpaex import sdpa_ex from thunder.executors.cudnnex import cudnn_ex from thunder.executors.fa3ex import fa3_ex from thunder.executors.transformer_engineex import transformer_engine_ex - #Current configuration + + # Current configuration options: dict[str, list] = { # TODO: filter out TE only if requested - 'linear': [transformer_engine_ex], - 'scaled_dot_product_attention': [sdpa_ex, cudnn_ex, fa3_ex], + "linear": [transformer_engine_ex], + "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], } return options.get(bsym.sym.name, []) if bsym else options @@ -42,6 +43,7 @@ def __init__( self.label: str | Hashable = label self.index: int = index + class OptimizerType(Enum): MEMORY = 0 RUNTIME = 1 @@ -154,6 +156,7 @@ def log(what: str, level: LogLevel = LogLevel.INFO): if log_level == LogLevel.DEBUG or log_level == level: print(f"================================================================================ Autotune: {what}") + class PlacerBase: def __init__( self, @@ -164,7 +167,7 @@ def __init__( log_file_name: str, optimizer_type: OptimizerType = OptimizerType.RUNTIME, compile_data, - ) -> None: + ) -> None: self.always_executors: tuple[Executor, ...] = get_always_executors() self.empty_executor_hashable_placeholder: str = "empty" self.executors: Sequence[Executor] = priority_executors @@ -206,6 +209,7 @@ def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: return (TraceCtx(), TraceCtx()) + class FusionPlacer_BeamSearch(PlacerBase): def __init__( self, @@ -223,7 +227,7 @@ def __init__( apply_bucketing_bw_trace=apply_bucketing_bw_trace, log_file_name=log_file_name, optimizer_type=optimizer_type, - compile_data=compile_data + compile_data=compile_data, ) # Strat fusion @@ -232,6 +236,7 @@ def __init__( from thunder.executors.nvfuserex_impl import linear, _linear_check from thunder.executors.nvfuserex_impl import matmul, _matmul_check + self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { "nvfuser": [ FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), @@ -268,6 +273,7 @@ def _best_runtime_and_memory_candidates(self, candidates): # We want to verify that it is not set to false if self.compile_data.use_cudagraphs is None or self.compile_data.use_cudagraphs == True: from thunder.executors.cudagraphex import cudagraphex + pair_options.extend( [ (cudagraphex.fusion_pass(pair.fw), cudagraphex.fusion_pass(pair.bw)), @@ -343,7 +349,9 @@ def fw_benchmark(): for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): log(f'Caching fw candidate [compile option: {o.fusion_tag if o else "None"}]') self.cached_fw_traces.append( - TraceCandidate(trace=t, compile_opt=o, label=label + '_enabled_' + o.fusion_tag if o is not None else label) + TraceCandidate( + trace=t, compile_opt=o, label=label + "_enabled_" + o.fusion_tag if o is not None else label + ) ) def bw_benchmark(): @@ -355,7 +363,9 @@ def bw_benchmark(): # Unpack the dict label = list(pair.keys())[0] trace = list(pair.values())[0] - trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) + trace_time, trace_mem, res = benchmark_trace( + trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" # log( # f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', @@ -370,7 +380,9 @@ def bw_benchmark(): label = list(pair.keys())[0] trace = list(pair.values())[0] - trace_time, trace_mem, res = benchmark_trace(trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) + trace_time, trace_mem, res = benchmark_trace( + trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) del res self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" # log( @@ -489,6 +501,7 @@ def _search(ex: FusionExecutor, executor_compile_option: FusionCompileOptionsHel """ Fusable fn definition for nvFuser """ + # Each executor has a custom should fuse function, but the current impl need to access local executor object def _should_fuse_nvfuser(a: Node, b: Node): def _can_fuse_node(n: Node): @@ -505,6 +518,7 @@ def _can_fuse_node(n: Node): """ Fusable fn definition for torch.compile """ + def _should_fuse_torchcompile(a: Node, b: Node): def _can_fuse_node(n: Node): if len(n.group_bsyms) > 1: @@ -531,13 +545,11 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): merge_fn: Callable match ex.name: - case 'nvfuser': + case "nvfuser": merge_fn = _should_fuse_nvfuser - case 'torchcompile': + case "torchcompile": merge_fn = _should_fuse_torchcompile - bound_symbol_groups = fuse_bound_symbols( - self.trace, merge_fn - ) + bound_symbol_groups = fuse_bound_symbols(self.trace, merge_fn) log(f"Number of Fusion groups = {len(bound_symbol_groups)}", level=LogLevel.DEBUG) # Print fusion groups if requested @@ -553,8 +565,14 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): increasing_symbols = [] for group_id, group in enumerate(bound_symbol_groups): log(f"Fusion group id: {group_id}", level=LogLevel.DEBUG) - log(f"Fusion group start = [{group[0].output.name if hasattr(group[0].output, 'name') else 'unknown'} = {group[0].sym.name}]", level=LogLevel.DEBUG) - log(f"Fusion group end = [{group[-1].output.name if hasattr(group[-1].output, 'name') else 'unknown'} = {group[-1].sym.name}]", level=LogLevel.DEBUG) + log( + f"Fusion group start = [{group[0].output.name if hasattr(group[0].output, 'name') else 'unknown'} = {group[0].sym.name}]", + level=LogLevel.DEBUG, + ) + log( + f"Fusion group end = [{group[-1].output.name if hasattr(group[-1].output, 'name') else 'unknown'} = {group[-1].sym.name}]", + level=LogLevel.DEBUG, + ) if group[0].sym.name != "return": increasing_symbols += group @@ -562,7 +580,10 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # We assign to a Fusion executor only region with at least 2 elements. Otherwise let the best OperatorExecutor pick the symbol up if len(group) < 2: current_bsym = group[0] - log(f"--> Single group: [{current_bsym.output.name if hasattr(current_bsym.output, 'name') else 'unknown'} = {current_bsym.sym.name}]", level=LogLevel.DEBUG) + log( + f"--> Single group: [{current_bsym.output.name if hasattr(current_bsym.output, 'name') else 'unknown'} = {current_bsym.sym.name}]", + level=LogLevel.DEBUG, + ) # Filter out all possible candidates for the current symbol candidate_executors = [ ex @@ -603,12 +624,17 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # TODO: enable requests for no remat becnhmarks # TODO: we should consider also FusionExecutor that can execute this single bsym in this beam search for i, candidate in enumerate(candidate_executors): - from thunder.common import transform_for_execution + subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.DEBUG) - t, m, _ = benchmark_trace(subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) - log(f'Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB', level=LogLevel.DEBUG) + t, m, _ = benchmark_trace( + subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) + log( + f"Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB", + level=LogLevel.DEBUG, + ) # Update results if t < candidate_best_time.runtime: candidate_best_time = BenchmarkResult(time=t, index=i) @@ -687,7 +713,7 @@ def measure_and_update_result(): n_missing_bsyms = len(group) - start_idx # TODO (matteochen): consider to add the iteration with no fusion regions - for i in range(0, n_missing_bsyms, n_missing_bsyms-1 if self.trace_type == TraceType.BW else 1): + for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): # for i in range(0, n_missing_bsyms): # From top to bottom (this will include the whole region) # -> First iteration is the one with fusion region with single element @@ -844,7 +870,7 @@ def measure_and_update_result(): # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. # TODO: Consider implementing patterns based on the executor under investingation if ex_compile_opts: - log(f'{ex.name} compile options: {[option.fusion_tag for option in ex_compile_opts]}') + log(f"{ex.name} compile options: {[option.fusion_tag for option in ex_compile_opts]}") for opt in ex_compile_opts: # Search only if we have an instruction related to the compile option op_in_trace: bool = operation_in_trace(trace=self.trace, op=opt.symbol_tag) @@ -968,7 +994,7 @@ def _optimize(): # Set the current active cached forward trace context log( f'Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.compile_opt.fusion_tag if fw_trace_candidate.compile_opt is not None else "None"}', - level=LogLevel.DEBUG + level=LogLevel.DEBUG, ) self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.compile_opt @@ -1008,16 +1034,14 @@ def __init__( compile_data, ) -> None: if optimizer_algorithm != OptimizationAlgorithm.BEST_FUSER: - raise AssertionError(f'Optimization {optimizer_algorithm} not implemented') - self.optimizer: PlacerBase = ( - FusionPlacer_BeamSearch( - priority_executors=priority_executors, - produce_log=produce_log, - apply_bucketing_bw_trace=apply_bucketing_bw_trace, - log_file_name=log_file_name, - optimizer_type=optimizer_type, - compile_data=compile_data, - ) + raise AssertionError(f"Optimization {optimizer_algorithm} not implemented") + self.optimizer: PlacerBase = FusionPlacer_BeamSearch( + priority_executors=priority_executors, + produce_log=produce_log, + apply_bucketing_bw_trace=apply_bucketing_bw_trace, + log_file_name=log_file_name, + optimizer_type=optimizer_type, + compile_data=compile_data, ) log("Executors:", level=LogLevel.INFO) @@ -1038,4 +1062,3 @@ def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: return self.optimizer.get_optimal_fw_bw_traces() - diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index f8937ed05d..db4a025a9f 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -1,69 +1,85 @@ -from typing import Sequence +from typing import Callable, Sequence import thunder.backend_optimizer.utils as aut_utils import pytest import torch import thunder +from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs +from thunder.core.proxies import FloatProxy, IntegerProxy, TensorProxy from thunder.core.symbol import BoundSymbol from thunder.core.trace import TraceCtx from thunder.extend import Executor, get_always_executors from thunder.executors.torchex import ex as torchex from thunder.executors.torch_compile import torch_compile_ex from thunder.executors.nvfuserex import nvfuserex -from thunder.tests.framework import requiresCUDA, run_snippet +from thunder.tests.framework import requiresCUDA -class DummyProxy(): + +class DummyProxy: def __init__(self, name) -> None: self.name = name -@pytest.mark.parametrize("data,expected", + +@pytest.mark.parametrize( + "data,expected", [ - ([DummyProxy('a'), DummyProxy('b')], '[a#b#]'), - ([DummyProxy('a'), DummyProxy('b'), 90], '[a#b#int90#]'), - ([DummyProxy('a'), DummyProxy('b'), 90, None], '[a#b#int90#None#]'), - ([DummyProxy('a'), DummyProxy('b'), 90, [DummyProxy('c'), [DummyProxy('d')]]], '[a#b#int90#[c#[d#]]]') - ] + ([DummyProxy("a"), DummyProxy("b")], "[a#b#]"), + ([DummyProxy("a"), DummyProxy("b"), 90], "[a#b#int90#]"), + ([DummyProxy("a"), DummyProxy("b"), 90, None], "[a#b#int90#None#]"), + ([DummyProxy("a"), DummyProxy("b"), 90, [DummyProxy("c"), [DummyProxy("d")]]], "[a#b#int90#[c#[d#]]]"), + ], ) def test_sequence_hash(data, expected): - assert(aut_utils.sequence_hash(data) == expected) + assert aut_utils.sequence_hash(data) == expected + -@pytest.mark.parametrize("data,expected", +@pytest.mark.parametrize( + "data,expected", [ - ([DummyProxy('a'), "b"], '[a#b#]'), - ] + ([DummyProxy("a"), "b"], "[a#b#]"), + ], ) def test_sequence_hash_bad_input(data, expected): with pytest.raises(AssertionError): assert aut_utils.sequence_hash(data) == expected -@pytest.mark.parametrize("data,expected_sum,expected_others", +@pytest.mark.parametrize( + "data,expected_sum,expected_others", [ - ([nvfuserex, torch_compile_ex], Executor(name='empty'), Executor(name='empty')), - ([nvfuserex, torchex], torchex, Executor(name='empty')), - ] + ([nvfuserex, torch_compile_ex], Executor(name="empty"), Executor(name="empty")), + ([nvfuserex, torchex], torchex, Executor(name="empty")), + ], ) -def test_first_available_operator_executor(data,expected_sum,expected_others): +def test_first_available_operator_executor(data, expected_sum, expected_others): def fn(a: torch.Tensor, b: torch.Tensor): return a + b - a = torch.randn(1,1) - b = torch.randn(1,1) + a = torch.randn(1, 1) + b = torch.randn(1, 1) jitted = thunder.jit(fn) jitted(a, b) trace = thunder.last_traces(jitted)[-1] for bsym in trace.bound_symbols: if bsym.sym.id == PrimIDs.ADD: - assert aut_utils.get_first_available_operator_executor(bsym = bsym, executors = data, empty_hash = 'empty') == expected_sum + assert ( + aut_utils.get_first_available_operator_executor(bsym=bsym, executors=data, empty_hash="empty") + == expected_sum + ) else: - assert aut_utils.get_first_available_operator_executor(bsym = bsym, executors = data, empty_hash = 'empty') == expected_others + assert ( + aut_utils.get_first_available_operator_executor(bsym=bsym, executors=data, empty_hash="empty") + == expected_others + ) -@pytest.mark.parametrize("test,expected", + +@pytest.mark.parametrize( + "test,expected", [ ([1, 2, 3], [1, 2, 3]), ([1, 2, [3, 4]], [1, 2, 3, 4]), ([1, 2, [3, 4, [None]]], [1, 2, 3, 4]), - ] + ], ) def test_flatten_sequence(test, expected): assert aut_utils.flatten_sequence(test) == expected @@ -77,26 +93,28 @@ def fn(a: torch.Tensor, b: torch.Tensor): t3 = a / b return (a + b) + t2 - a = torch.randn(1,1) - b = torch.randn(1,1) + a = torch.randn(1, 1) + b = torch.randn(1, 1) jitted = thunder.jit(fn, disable_dce=True) jitted(a, b) trace = thunder.last_traces(jitted)[-1] not_used = aut_utils.get_not_used_intermediate_outsputs(trace) # We have not used t1, t3 in trace - not_used_labels = ['t1', 't3'] + not_used_labels = ["t1", "t3"] assert len(not_used) == 2 for t in not_used: assert t.name in not_used_labels not_used_labels.remove(t.name) + def _assign_executors_fn(a: torch.Tensor): t0 = a * 2 t1 = a * a t3 = t0 + t1 return t3 + @pytest.mark.parametrize( "fn, args, executors", [ @@ -119,7 +137,7 @@ def _assign_executors_fn(a: torch.Tensor): _assign_executors_fn, torch.randn(1, 1), [Executor("empty"), torchex, torch_compile_ex, torch_compile_ex, Executor("empty")], - ) + ), ], ) def test_assign_executors(fn, args, executors): @@ -131,9 +149,9 @@ def test_assign_executors(fn, args, executors): def _id(bsym: BoundSymbol): res = bsym.sym.name if isinstance(bsym.output, Sequence): - res += '#' + aut_utils.sequence_hash(bsym.output) + res += "#" + aut_utils.sequence_hash(bsym.output) else: - res += '#' + bsym.output.name + res += "#" + bsym.output.name return res @@ -141,8 +159,7 @@ def _id(bsym: BoundSymbol): executor_map = { _id(b): e if e.name != "empty" else None for b, e in zip(trace.bound_symbols, executors) - if b.output is not None and b.sym.id != PrimIDs.RETURN - and b.args is not None + if b.output is not None and b.sym.id != PrimIDs.RETURN and b.args is not None } for b in placed.bound_symbols: @@ -156,6 +173,7 @@ def _id(bsym: BoundSymbol): identif = _id(b) assert b.sym.executor == executor_map[identif] + class Linear(torch.nn.Module): def __init__(self, a, b) -> None: super().__init__() @@ -164,6 +182,7 @@ def __init__(self, a, b) -> None: def forward(self, x): return self.linear(x) + class Matmul(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -171,27 +190,13 @@ def __init__(self) -> None: def forward(self, x): return x @ x + @pytest.mark.parametrize( "model, x, op, expected", [ - ( - Linear(8,8), - torch.randn(8,8), - 'linear', - True - ), - ( - Linear(8,8), - torch.randn(8,8), - 'add', - False - ), - ( - Matmul(), - torch.randn(8,8), - 'matmul', - True - ), + (Linear(8, 8), torch.randn(8, 8), "linear", True), + (Linear(8, 8), torch.randn(8, 8), "add", False), + (Matmul(), torch.randn(8, 8), "matmul", True), ], ) def test_operation_in_trace(model, x, op, expected): @@ -202,6 +207,7 @@ def test_operation_in_trace(model, x, op, expected): assert aut_utils.operation_in_trace(trace=trace, op=op) == expected + class Sdpa(torch.nn.Module): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) @@ -209,29 +215,177 @@ def __init__(self, *args, **kwargs) -> None: def forward(self, q, k, v): return torch.nn.functional.scaled_dot_product_attention(q, k, v) -@pytest.mark.parametrize("device,", ["cuda"]) + +@pytest.mark.parametrize( + "model, q, k, v, executors, expected", + [ + ( + Sdpa(), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + ["cudnn", "sdpa", "fa3"], + 1, + ), + ( + Sdpa(), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + ["cudnn", "sdpa"], + 1, + ), + ( + Sdpa(), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + [ + "cudnn", + ], + 1, + ), + ( + Sdpa(), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + [], + 0, + ), + ], +) @requiresCUDA -def test_sdpa(device: str): - batch = 10 - seq_len = 128 - num_heads = 4 - dim_per_head = 32 +# Currently these executors are: cudnn, spda, fa3, TE +def test_update_compile_options_executor_list_after_fw_bw_split(model, q, k, v, executors, expected): + jitted = thunder.jit(model, autotune_type="runtime", executors=executors) + jitted(q, k, v) + + assigned: Sequence[Executor] = thunder.executors_applied(jitted) + + count = 0 + for ex in assigned: + count += 1 if ex.name in executors else 0 + + assert count == expected + + +def _test_transform_proxy_to_torch_fn_1(a: torch.Tensor, b: torch.Tensor, k: int): + t0 = a * b + return t0 * k + + +def _test_transform_proxy_to_torch_fn_2( + a: torch.Tensor, b: torch.Tensor, c: tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] +): + t0 = c[0] + c[1][0] + t1 = t0 * c[1][1] + return t1 - a + b - query = torch.randn([batch, seq_len, num_heads, dim_per_head], dtype=torch.float16, device=device, requires_grad=True) - key = torch.randn([batch, seq_len, num_heads, dim_per_head], dtype=torch.float16, device=device, requires_grad=True) - value = torch.randn([batch, seq_len, num_heads, dim_per_head], dtype=torch.float16, device=device, requires_grad=True) - model = Sdpa() - executors = ['cudnn', 'sdpa'] - jitted = thunder.jit(model, autotune_type='runtime', executors=executors) - jitted(query, key, value) +def _test_transform_proxy_to_torch_common( + fn: Callable, torch_args: tuple, executors: list, has_backward: bool, **kwargs +): + jitted = thunder.jit(fn, executors=executors) + jitted(*torch_args) - exs: Sequence[Executor] = thunder.executors_applied(jitted) + trace_static_args = thunder.last_traces(jitted)[-1].args + assert trace_static_args - sdpa_executors_occurences = 0 - for ex in exs: - if ex.name in executors: - sdpa_executors_occurences += 1 + transformed_args = aut_utils.transform_proxy_to_torch(trace_static_args, **kwargs) - assert sdpa_executors_occurences == 1 + assert isinstance(transformed_args, list) + def _comp(thunder_seq: Sequence, torch_seq: Sequence): + assert len(thunder_seq) == len(torch_seq) + + for a, b in zip(thunder_seq, torch_seq): + if isinstance(a, TensorProxy): + # handle TE fp32 + # Static type for fp8 is torch.float8 but the runtime is TE Float8 if TE is being used + if a.dtype.bytes == 1 and kwargs.get("te_used"): + assert b.dtype == torch.float32 + else: + assert b.dtype == to_torch_dtype(a.dtype) + assert a.device.device_str() == str(b.device) + assert a.shape == b.shape + assert a.requires_grad == b.requires_grad + elif isinstance(a, IntegerProxy) or isinstance(a, FloatProxy): + assert a.value == b + + if isinstance(a, Sequence): + assert isinstance(b, Sequence) + _comp(a, b) + + _comp(trace_static_args, transformed_args) + + if has_backward: + trace_static_args = thunder.last_backward_traces(jitted)[-1].args + assert trace_static_args + + transformed_args = aut_utils.transform_proxy_to_torch(trace_static_args, **kwargs) + print(trace_static_args) + # print(transformed_args) + + _comp(trace_static_args, transformed_args) + + +@pytest.mark.parametrize( + "fn, torch_args, executors, has_backward", + [ + (_test_transform_proxy_to_torch_fn_1, tuple([torch.randn(1, 1), torch.randn(1, 1), 10]), [], False), + ( + _test_transform_proxy_to_torch_fn_2, + tuple([torch.randn(1, 1), torch.randn(1, 1), (torch.randn(1, 1), (torch.randn(1, 1), torch.rand(1, 1)))]), + [], + False, + ), + ( + Sdpa(), + ( + torch.randn([10, 128, 4, 32], dtype=torch.float16, requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, requires_grad=True), + ), + [], + True, + ), + ], +) +def test_transform_proxy_to_torch(fn: Callable, torch_args: tuple, executors: list, has_backward: bool): + _test_transform_proxy_to_torch_common(fn, torch_args, executors, has_backward) + + +@requiresCUDA +def test_transform_proxy_to_torch_TE(): + class Model(torch.nn.Module): + def __init__(self, in_features, out_features) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features) + + def forward(self, x: torch.Tensor): + return self.linear(x) + + model = Model(4096, 4096) + model.to("cuda") + + _test_transform_proxy_to_torch_common( + model, + tuple([torch.randn(4096, 4096, requires_grad=True, device="cuda")]), + ["transformer_engine"], + True, + te_used=True, + ) + + +@pytest.mark.parametrize( + "executors, expected", + [ + (["python"], ["nvfuser", "python"]), + (["nvfuser", "cudnn"], ["cudnn", "nvfuser"]), + (["torch", "nvfuser", "sdpa"], ["sdpa", "torch", "nvfuser"]), + (["transformer_engine", "nvfuser", "sdpa"], ["transformer_engine", "sdpa", "nvfuser"]), + ], +) +def test_reorder_executors_list(executors, expected): + assert aut_utils.reorder_executors_list(executors) == expected From e33f29c07ecd37554ebc2a9e94cdf7f41179b66a Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 11:41:30 +0300 Subject: [PATCH 105/171] Formatter --- examples/dev/MLP.py | 36 +++-- examples/dev/conv2d_relu.py | 5 +- examples/dev/litGPT.py | 65 +++++---- examples/dev/nanogpt-block.py | 83 +++++++----- examples/dev/nanogpt.py | 95 ++++++++------ examples/dev/nvfuser_optimizations.py | 23 ++-- examples/dev/nvmath.py | 182 ++++++++++++++++++++++++++ examples/dev/sdpa.py | 37 +++--- examples/dev/sdpa_linear.py | 74 +++++++---- examples/dev/simple.py | 20 ++- examples/dev/te.py | 25 ++-- examples/dev/test_del.py | 7 +- thunder/backend_optimizer/utils.py | 99 +++++++++----- thunder/benchmarks/utils.py | 83 +++++++----- 14 files changed, 582 insertions(+), 252 deletions(-) create mode 100644 examples/dev/nvmath.py diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py index 0f3317a328..0f249ee200 100644 --- a/examples/dev/MLP.py +++ b/examples/dev/MLP.py @@ -1,7 +1,13 @@ import torch import torch.nn as nn import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark +from thunder.benchmarks.utils import ( + thunder_fw_bw_benchmark, + torch_fw_bw_benchmark, + torch_fw_bw_benchmark_nvsight, + torch_total_benchmark, +) + class ModelConfig: def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): @@ -11,12 +17,13 @@ def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): self.bias = bias self.block_size = block_size + class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.gelu = nn.GELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): @@ -26,7 +33,8 @@ def forward(self, x): x = self.dropout(x) return x -with torch.device('cuda'): + +with torch.device("cuda"): embeddings = 3072 config = ModelConfig(n_embd=embeddings, dropout=0.0, bias=False) dtype = torch.float32 @@ -36,19 +44,24 @@ def forward(self, x): jmodel_def = thunder.jit(model) # This model fails under some circumstances after passed the placed traced under the rematelizer - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'torch', 'python'], use_cudagraphs=False) + jmodel_auto = thunder.jit( + model, + autotune_type="runtime", + executors=["nvfuser", "torchcompile", "sdpa", "torch", "python"], + use_cudagraphs=False, + ) - print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 callables = [jmodel_auto, jmodel_def] - labels = ['auto', 'def'] + labels = ["auto", "def"] inputs = [x, x] - print('Results with torch total benchmark:') + print("Results with torch total benchmark:") torch_total_benchmark(callables, labels, inputs, iters) - print('Results with thunder benchmark:') + print("Results with thunder benchmark:") fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], @@ -64,4 +77,3 @@ def forward(self, x): # for t in traces: # print(t) # print('##########################') - diff --git a/examples/dev/conv2d_relu.py b/examples/dev/conv2d_relu.py index 2089184082..c7b9386a26 100644 --- a/examples/dev/conv2d_relu.py +++ b/examples/dev/conv2d_relu.py @@ -1,6 +1,7 @@ import torch import thunder + class Module(torch.nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1) -> None: super().__init__() @@ -14,7 +15,8 @@ def forward(self, x: torch.Tensor): d = self.relu(b * a) return c + d -with torch.device('cuda'): + +with torch.device("cuda"): model = Module(16, 33, 3, stride=2) x = torch.randn(20, 16, 50, 100) @@ -27,4 +29,3 @@ def forward(self, x: torch.Tensor): # print('##############################################') # print('---------------------------------------------- ans') # print(ans) - diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 363f548ecb..839b02b118 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -1,15 +1,29 @@ from litgpt import GPT -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark +from thunder.benchmarks.utils import ( + thunder_fw_bw_benchmark, + torch_fw_bw_benchmark, + torch_fw_bw_benchmark_nvsight, + torch_total_benchmark, +) from thunder.tests.litgpt_model import Config import thunder import torch from thunder.executors.nvmathex import nvmath_ex -torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul -torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + class Test: - def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: int = -1, model_name: str = 'Llama-3-8B', executors = None) -> None: + def __init__( + self, + layers: int, + autotune_type: str, + batch_size: int, + seq_len: int = -1, + model_name: str = "Llama-3-8B", + executors=None, + ) -> None: self.layers = layers self.autotune_type = autotune_type self.batch_size = batch_size @@ -17,6 +31,7 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in self.model_name = model_name self.executors = executors + layers = [ Test( 1, @@ -34,7 +49,12 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in 1, "runtime", 1, - executors=["cudnn", "sdpa", "nvfuser", "torchcompile",], + executors=[ + "cudnn", + "sdpa", + "nvfuser", + "torchcompile", + ], ), Test( 1, @@ -52,10 +72,10 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in cfg.block_size = test.seq_len torch.set_default_dtype(torch.bfloat16) print(cfg) - with torch.device('cuda'): + with torch.device("cuda"): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size)) - print(f'Input size: {x.size()}') + print(f"Input size: {x.size()}") jmodel_def = thunder.jit(model) jmodel_auto = thunder.jit( @@ -65,8 +85,8 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in use_cudagraphs=False, ) - print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 fw_traces = [ @@ -79,31 +99,32 @@ def __init__(self, layers: int, autotune_type: str, batch_size: int, seq_len: in ] fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] - print(f'Results thunder benchmark ({iters} iters):') + print(f"Results thunder benchmark ({iters} iters):") thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=True) thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) print(test.model_name) - print(f'\n\nResults torch fw bw benchmark ({iters} iters):') + print(f"\n\nResults torch fw bw benchmark ({iters} iters):") callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] + labels = ["def", "auto"] inputs = [x, x] torch_fw_bw_benchmark(callables, labels, inputs, iters) - print(f'\n\nResults torch total benchmark ({iters} iters):') + print(f"\n\nResults torch total benchmark ({iters} iters):") torch_total_benchmark(callables, labels, inputs, iters) torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') + print("\n\n\n\n\n\n") + print(f"{thunder.last_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_traces(jmodel_auto)[-1]}") - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + print("\n\n") + print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") except Exception as e: - print(f'Test failed:\n{e}') + print(f"Test failed:\n{e}") import traceback + traceback.print_exc() diff --git a/examples/dev/nanogpt-block.py b/examples/dev/nanogpt-block.py index 2992525296..f1f1693e22 100644 --- a/examples/dev/nanogpt-block.py +++ b/examples/dev/nanogpt-block.py @@ -7,8 +7,9 @@ # torch.set_default_dtype(torch.bfloat16) + class LayerNorm(nn.Module): - """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ + """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" def __init__(self, ndim, bias): super().__init__() @@ -18,8 +19,8 @@ def __init__(self, ndim, bias): def forward(self, input): return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) -class CausalSelfAttention(nn.Module): +class CausalSelfAttention(nn.Module): def __init__(self, config): super().__init__() assert config.n_embd % config.n_head == 0 @@ -34,46 +35,52 @@ def __init__(self, config): self.n_embd = config.n_embd self.dropout = config.dropout # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 - self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") if not self.flash: print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) - .view(1, 1, config.block_size, config.block_size)) + self.register_buffer( + "bias", + torch.tril(torch.ones(config.block_size, config.block_size)).view( + 1, 1, config.block_size, config.block_size + ), + ) def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + y = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True + ) else: # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) return y -class MLP(nn.Module): +class MLP(nn.Module): def __init__(self, config): super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.gelu = nn.GELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) + self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) + self.gelu = nn.GELU() + self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) self.dropout = nn.Dropout(config.dropout) def forward(self, x): @@ -83,17 +90,18 @@ def forward(self, x): x = self.dropout(x) return x + class GPTConfig: block_size: int = 1024 - vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency n_layer: int = 12 n_head: int = 12 n_embd: int = 3072 dropout: float = 0.0 - bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster -class Block(nn.Module): +class Block(nn.Module): def __init__(self, config): super().__init__() self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) @@ -106,19 +114,22 @@ def forward(self, x): x = x + self.mlp(self.ln_2(x)) return x -with torch.device('cuda'): + +with torch.device("cuda"): config = GPTConfig() model = Block(config) x = torch.randn((16, 1024, 3072), dtype=torch.float32) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'torchcompile', 'sdpa', 'cudnn', 'torch', 'python']) + jmodel_auto = thunder.jit( + model, autotune_type="runtime", executors=["nvfuser", "torchcompile", "sdpa", "cudnn", "torch", "python"] + ) - print('deviation def:', (jmodel_def(x) - model(x)).abs().max().item()) - print('deviation auto:', (jmodel_auto(x) - model(x)).abs().max().item()) + print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 - print('Results thunder benchmark:') + print("Results thunder benchmark:") fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], @@ -131,18 +142,18 @@ def forward(self, x): bw_labels = ["bw_def", "bw_auto"] thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - print('\n\nResults torch fw bw benchmark:') + print("\n\nResults torch fw bw benchmark:") callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] + labels = ["def", "auto"] inputs = [x, x] torch_fw_bw_benchmark(callables, labels, inputs, 100) - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') + print("\n\n\n\n\n\n") + print(f"{thunder.last_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_traces(jmodel_auto)[-1]}") - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + print("\n\n") + print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index bcf7e91781..d70a452f85 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -6,33 +6,36 @@ warm_up_iters = 50 -def run(target: str = 'runtime'): - if target != 'runtime' and target != 'memory': - raise AssertionError(f'Target {target} not supported. Only runtime and memory available') + +def run(target: str = "runtime"): + if target != "runtime" and target != "memory": + raise AssertionError(f"Target {target} not supported. Only runtime and memory available") # ----------------------------------------------------------------------------- batch_size = 12 block_size = 1024 bias = False real_data = False seed = 1337 - device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. - dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' - compile = False # use PyTorch 2.0 to compile the model to be faster - profile = False # use pytorch profiler, or just simple benchmarking? + device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. + dtype = ( + "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16" + ) # 'float32' or 'bfloat16' or 'float16' + compile = False # use PyTorch 2.0 to compile the model to be faster + profile = False # use pytorch profiler, or just simple benchmarking? # exec(open('configurator.py').read()) # overrides from command line or config file # ----------------------------------------------------------------------------- torch.manual_seed(seed) torch.cuda.manual_seed(seed) - torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul - torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast - ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] - ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) + torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul + torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast + ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] + ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype) # data loading init if real_data: - raise RuntimeError('Not supported') + raise RuntimeError("Not supported") else: # alternatively, if fixed data is desired to not care about data loading x = torch.randint(50304, (batch_size, block_size), device=device) @@ -41,20 +44,27 @@ def run(target: str = 'runtime'): # model init gptconf = GPTConfig( - block_size = block_size, # how far back does the model look? i.e. context size - n_layer = 4, n_head = 12, n_embd = 768, # size of the model - dropout = 0, # for determinism - bias = bias, + block_size=block_size, # how far back does the model look? i.e. context size + n_layer=4, + n_head=12, + n_embd=768, # size of the model + dropout=0, # for determinism + bias=bias, ) model = GPT(gptconf) model.to(device) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type=target, executors = ['torchcompile', 'nvfuser', 'cudnn', 'sdpa', 'transformer_engine'], use_cudagraphs=False) + jmodel_auto = thunder.jit( + model, + autotune_type=target, + executors=["torchcompile", "nvfuser", "cudnn", "sdpa", "transformer_engine"], + use_cudagraphs=False, + ) if compile: print("Compiling model...") - model = torch.compile(model) # pytorch 2.0 + model = torch.compile(model) # pytorch 2.0 if profile: # useful docs on pytorch profiler: @@ -65,28 +75,27 @@ def run(target: str = 'runtime'): with torch.profiler.profile( activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), + on_trace_ready=torch.profiler.tensorboard_trace_handler("./bench_log"), record_shapes=False, profile_memory=False, - with_stack=False, # incurs an additional overhead, disable if not needed + with_stack=False, # incurs an additional overhead, disable if not needed with_flops=True, - with_modules=False, # only for torchscript models atm + with_modules=False, # only for torchscript models atm ) as prof: - models = [jmodel_def, jmodel_auto] for mod in models: - print('Profiling new model') - X, Y = get_batch('train') + print("Profiling new model") + X, Y = get_batch("train") for k in range(num_steps): with ctx: _, loss = mod(X, Y) - X, Y = get_batch('train') + X, Y = get_batch("train") loss.backward() lossf = loss.item() print(f"{k}/{num_steps} loss: {lossf:.4f}") - prof.step() # notify the profiler at end of each step + prof.step() # notify the profiler at end of each step else: # simple benchmarking @@ -95,7 +104,7 @@ def measure(m, label): torch.cuda.synchronize() for i in range(warm_up_iters): - X, Y = get_batch('train') + X, Y = get_batch("train") with ctx: _, loss = m(X, Y) loss.backward() @@ -107,7 +116,7 @@ def measure(m, label): for i in range(iters): torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) - X, Y = get_batch('train') + X, Y = get_batch("train") start_events[i].record(stream) with ctx: _, loss = m(X, Y) @@ -117,13 +126,13 @@ def measure(m, label): torch.cuda.synchronize() tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] tot_time = sum(tot) / iters - print('\n\nResults torch benchmark:') - print(f'{label} tot time: {tot_time} ms') + print("\n\nResults torch benchmark:") + print(f"{label} tot time: {tot_time} ms") def measure_nvsight(m, label): # Warm up for _ in range(warm_up_iters): - X, Y = get_batch('train') + X, Y = get_batch("train") with ctx: _, loss = m(X, Y) loss.backward() @@ -134,7 +143,7 @@ def measure_nvsight(m, label): # Perform less iterations for _ in range(20): torch.cuda.empty_cache() - X, Y = get_batch('train') + X, Y = get_batch("train") torch.cuda.nvtx.range_push(f"{label}: fw-bw") with ctx: _, loss = m(X, Y) @@ -142,10 +151,10 @@ def measure_nvsight(m, label): torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() - measure(jmodel_auto, 'auto') - measure(jmodel_def, 'def') + measure(jmodel_auto, "auto") + measure(jmodel_def, "def") - print('\n\nResults thunder benchmark:') + print("\n\nResults thunder benchmark:") fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], @@ -158,11 +167,17 @@ def measure_nvsight(m, label): bw_labels = ["bw_def", "bw_auto"] thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 100) - measure_nvsight(jmodel_def, 'def') - measure_nvsight(jmodel_auto, 'auto') + measure_nvsight(jmodel_def, "def") + measure_nvsight(jmodel_auto, "auto") - traces = [thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], thunder.last_backward_traces(jmodel_def)[-1], thunder.last_backward_traces(jmodel_auto)[-1]] + traces = [ + thunder.last_traces(jmodel_def)[-1], + thunder.last_traces(jmodel_auto)[-1], + thunder.last_backward_traces(jmodel_def)[-1], + thunder.last_backward_traces(jmodel_auto)[-1], + ] for t in traces: - print(f'{t}\n############################################') + print(f"{t}\n############################################") + run() diff --git a/examples/dev/nvfuser_optimizations.py b/examples/dev/nvfuser_optimizations.py index abd3fe09bc..720eaa7a0a 100644 --- a/examples/dev/nvfuser_optimizations.py +++ b/examples/dev/nvfuser_optimizations.py @@ -1,6 +1,12 @@ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark +from thunder.benchmarks.utils import ( + thunder_fw_bw_benchmark, + torch_fw_bw_benchmark, + torch_fw_bw_benchmark_nvsight, + torch_total_benchmark, +) + class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: @@ -9,7 +15,7 @@ def __init__(self, in_features, out_features) -> None: torch.nn.Linear(in_features, out_features), torch.nn.Linear(out_features, in_features), torch.nn.Linear(in_features, out_features), - torch.nn.Linear(out_features, in_features) + torch.nn.Linear(out_features, in_features), ) self.silu = torch.nn.SiLU() @@ -20,7 +26,8 @@ def forward(self, x: torch.Tensor): c = c @ torch.transpose(c, 0, 1) return self.silu(c) -with torch.device('cuda'): + +with torch.device("cuda"): in_features = 1 << 8 out_features = 1 << 10 model = Module(in_features, out_features) @@ -33,7 +40,7 @@ def forward(self, x: torch.Tensor): y = jmodel_auto(x) iters = 100 - print('Results thunder benchmark:') + print("Results thunder benchmark:") fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], @@ -48,14 +55,14 @@ def forward(self, x: torch.Tensor): # thunder_fw_bw_benchmark(traces, labels, iters, nvsight=True) callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] + labels = ["def", "auto"] inputs = [x, x] - print('Results torch benchmark:') + print("Results torch benchmark:") torch_fw_bw_benchmark(callables, labels, inputs, iters) torch_total_benchmark(callables, labels, inputs, iters) torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) for t in fw_traces: - print(f'{t}\n#########################################') + print(f"{t}\n#########################################") for t in bw_traces: - print(f'{t}\n#########################################') + print(f"{t}\n#########################################") diff --git a/examples/dev/nvmath.py b/examples/dev/nvmath.py new file mode 100644 index 0000000000..db6d565038 --- /dev/null +++ b/examples/dev/nvmath.py @@ -0,0 +1,182 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. +# +# SPDX-License-Identifier: Apache-2.0 + +""" +This example demonstrates basic matrix multiplication of torch tensors. + +nvmath-python supports multiple frameworks. The result of each operation is a tensor of the same +framework that was used to pass the inputs. It is also located on the same device as the inputs. +""" + +import torch +import nvmath +import thunder + +torch.set_default_dtype(torch.bfloat16) +torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul +torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + +a = torch.randn(128256, 8192, device="cuda") +b = torch.randn(8192, 4096, device="cuda") + +iters = 20 +start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +stream = torch.cuda.current_stream() +for i in range(iters): + result_a = torch.matmul(a, b) +torch.cuda.default_stream().synchronize() +for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + result_a = torch.matmul(a, b) + end_events[i].record(stream) +torch.cuda.default_stream().synchronize() +tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] +tot_time = sum(tot) / iters +print(f"torch tot time: {tot_time} ms") + +start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] +stream = torch.cuda.current_stream() +torch.cuda.default_stream().synchronize() +for i in range(iters): + result_b = nvmath.linalg.advanced.matmul(a, b) +torch.cuda.default_stream().synchronize() +for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + result_b = nvmath.linalg.advanced.matmul(a, b) + end_events[i].record(stream) +torch.cuda.default_stream().synchronize() +tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] +tot_time = sum(tot) / iters +print(f"nvmath tot time: {tot_time} ms") + +with nvmath.linalg.advanced.Matmul(a, b) as mm: + # Plan. + mm.plan() + + # Inspect the algorithms proposed. + print( + f"Planning returned {len(mm.algorithms)} algorithms. The capabilities of the best one are:", + ) + best = mm.algorithms[0] + print(best.capabilities) + + # Modify the tiling configuration of the algorithm. Note that the valid tile configuration depends on + # the hardware, and not all combinations of the configuration are supported, so we leave it as an exercise. + print(f"Tiling {best.tile}") + # Execute the multiplication. + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + torch.cuda.default_stream().synchronize() + for i in range(iters): + result_c = mm.execute() + torch.cuda.default_stream().synchronize() + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + result_c = mm.execute() + end_events[i].record(stream) + torch.cuda.default_stream().synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f"nvmath tuned tot time: {tot_time} ms") + +with nvmath.linalg.advanced.Matmul(a, b) as mm: + # Plan. + mm.plan() + + # Autotune + mm.autotune(iterations=5) + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + stream = torch.cuda.current_stream() + torch.cuda.default_stream().synchronize() + for i in range(iters): + result_d = mm.execute() + torch.cuda.default_stream().synchronize() + for i in range(iters): + torch.cuda.empty_cache() + torch.cuda._sleep(1_000_000) + start_events[i].record(stream) + result_d = mm.execute() + end_events[i].record(stream) + torch.cuda.default_stream().synchronize() + tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + tot_time = sum(tot) / iters + print(f"nvmath tuned tot time: {tot_time} ms") + + +print("deviation:", (result_a - result_b).abs().max().item()) +print("deviation:", (result_a - result_c).abs().max().item()) +print("deviation:", (result_a - result_d).abs().max().item()) + +# from thunder.benchmarks.utils import thunder_fw_bw_benchmark +# from thunder.executors.nvmathex import nvmath_ex + +# class Module(torch.nn.Module): +# def __init__(self, *args, **kwargs) -> None: +# super().__init__(*args, **kwargs) + +# def forward(self, a, b): +# return a @ b + +# with torch.device('cuda'): + +# model = Module() +# jmodel_def = thunder.jit(model) +# jmodel_auto = thunder.jit(model, autotune_type="runtime", executors = [nvmath_ex], use_cudagraphs=False) +# a = torch.randn(128256, 128, requires_grad=True) +# b = torch.randn(128, 4096, requires_grad=True) + +# print('deviation def:', (jmodel_def(a, b) - model(a, b)).abs().max().item()) +# print('deviation auto:', (jmodel_auto(a, b) - model(a, b)).abs().max().item()) + +# from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark + +# print('Results with thunder benchmark:') +# fw_traces = [ +# thunder.last_traces(jmodel_def)[-1], +# thunder.last_traces(jmodel_auto)[-1], +# ] +# bw_traces = [ +# thunder.last_backward_traces(jmodel_def)[-1], +# thunder.last_backward_traces(jmodel_auto)[-1], +# ] +# fw_labels = ["fw_def", "fw_auto"] +# bw_labels = ["bw_def", "bw_auto"] +# thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10) + +# for t in fw_traces: +# print(f'{t}\n################################') +# for t in bw_traces: +# print(f'{t}\n################################') + +# for _ in range(iters): +# result_a = model(a.clone().detach(), b.clone().detach()) +# torch.cuda.default_stream().synchronize() +# s = time.time_ns() +# for i in range(iters): +# result_a = model(a.clone().detach(), b.clone().detach()) +# torch.cuda.default_stream().synchronize() +# e = time.time_ns() +# print('time torch', (e-s)/1000000, 'ms') + +# for _ in range(iters): +# result_b = jmodel(a.clone().detach(), b.clone().detach()) +# torch.cuda.default_stream().synchronize() +# s = time.time_ns() +# for i in range(iters): +# result_b = model(a.clone().detach(), b.clone().detach()) +# torch.cuda.default_stream().synchronize() +# e = time.time_ns() +# print('time nvmath', (e-s)/1000000, 'ms') diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index abe57d8407..e8f666c806 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -4,7 +4,8 @@ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 torch.set_default_dtype(dtype) -print(f'Script data type: {dtype}\n') +print(f"Script data type: {dtype}\n") + class Model(torch.nn.Module): def __init__(self) -> None: @@ -13,18 +14,21 @@ def __init__(self) -> None: def forward(self, query, key, value): a = torch.nn.functional.scaled_dot_product_attention(query, key, value) # Make different inputs as happens in a real model - b = torch.nn.functional.scaled_dot_product_attention(query+query, key+key, value+value) + b = torch.nn.functional.scaled_dot_product_attention(query + query, key + key, value + value) return a + b -with torch.device('cuda'): + +with torch.device("cuda"): model = Model() jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors = ['nvfuser', 'cudnn', 'sdpa'], use_cudagraphs=False) + jmodel_auto = thunder.jit( + model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa"], use_cudagraphs=False + ) - q = torch.rand(32, 8, 128, 64*1, requires_grad=True) - k = torch.rand(32, 8, 128, 64*1, requires_grad=True) - v = torch.rand(32, 8, 128, 64*1, requires_grad=True) + q = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) + k = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) + v = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) jmodel_def(q, k, v) jmodel_auto(q, k, v) @@ -40,16 +44,15 @@ def forward(self, query, key, value): ] fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] - print('Thunder benchmark:') + print("Thunder benchmark:") thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') - - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') + print("\n\n\n\n\n\n") + print(f"{thunder.last_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_traces(jmodel_auto)[-1]}") + print("\n\n") + print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") diff --git a/examples/dev/sdpa_linear.py b/examples/dev/sdpa_linear.py index 7627160828..0daa9498f1 100644 --- a/examples/dev/sdpa_linear.py +++ b/examples/dev/sdpa_linear.py @@ -4,6 +4,7 @@ torch.set_default_dtype(torch.float32) + class Model(torch.nn.Module): def __init__(self, inf, outf) -> None: super().__init__() @@ -14,40 +15,63 @@ def forward(self, query, key, value): a = torch.nn.functional.scaled_dot_product_attention(query, key, value) return a -with torch.device('cuda'): + +with torch.device("cuda"): features = 128 model = Model(features, features) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'cudnn', 'sdpa', 'fa3', 'torchcompile']) + jmodel_auto = thunder.jit( + model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa", "fa3", "torchcompile"] + ) q = torch.rand(32, 8, 128, features, requires_grad=True) k = torch.rand(32, 8, 128, features, requires_grad=True) v = torch.rand(32, 8, 128, features, requires_grad=True) - print('deviation def:', (jmodel_def(q, k, v) - model(q, k, v)).abs().max().item()) - print('deviation auto:', (jmodel_auto(q, k, v) - model(q, k, v)).abs().max().item()) - - print('########################################') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_fw', iters=10) - print(f'Executing default fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_fw', iters=10) - print(f'Executing auto fw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_def)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_def_bw', iters=10) - print(f'Executing default bw trace:\n{c} ms, {m / (2**30)} GB') - c, m, o = benchmark_trace(thunder.last_backward_traces(jmodel_auto)[-1], apply_del_last_used=False, snapshot=True, snapshot_name='sdpa_auto_bw', iters=10) - print(f'Executing auto bw trace:\n{c} ms, {m / (2**30)} GB') - - print('\n\n\n\n\n\n') - print(f'{thunder.last_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_traces(jmodel_auto)[-1]}') - - print('\n\n') - print(f'{thunder.last_backward_traces(jmodel_def)[-1]}') - print('###############################################################################') - print(f'{thunder.last_backward_traces(jmodel_auto)[-1]}') - + print("deviation def:", (jmodel_def(q, k, v) - model(q, k, v)).abs().max().item()) + print("deviation auto:", (jmodel_auto(q, k, v) - model(q, k, v)).abs().max().item()) + print("########################################") + c, m, o = benchmark_trace( + thunder.last_traces(jmodel_def)[-1], + apply_del_last_used=False, + snapshot=True, + snapshot_name="sdpa_def_fw", + iters=10, + ) + print(f"Executing default fw trace:\n{c} ms, {m / (2**30)} GB") + c, m, o = benchmark_trace( + thunder.last_traces(jmodel_auto)[-1], + apply_del_last_used=False, + snapshot=True, + snapshot_name="sdpa_auto_fw", + iters=10, + ) + print(f"Executing auto fw trace:\n{c} ms, {m / (2**30)} GB") + c, m, o = benchmark_trace( + thunder.last_backward_traces(jmodel_def)[-1], + apply_del_last_used=False, + snapshot=True, + snapshot_name="sdpa_def_bw", + iters=10, + ) + print(f"Executing default bw trace:\n{c} ms, {m / (2**30)} GB") + c, m, o = benchmark_trace( + thunder.last_backward_traces(jmodel_auto)[-1], + apply_del_last_used=False, + snapshot=True, + snapshot_name="sdpa_auto_bw", + iters=10, + ) + print(f"Executing auto bw trace:\n{c} ms, {m / (2**30)} GB") + print("\n\n\n\n\n\n") + print(f"{thunder.last_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_traces(jmodel_auto)[-1]}") + print("\n\n") + print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") + print("###############################################################################") + print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") diff --git a/examples/dev/simple.py b/examples/dev/simple.py index be19d6fd62..bc88314cce 100644 --- a/examples/dev/simple.py +++ b/examples/dev/simple.py @@ -2,6 +2,7 @@ import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark + class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: super().__init__() @@ -14,20 +15,27 @@ def forward(self, x: torch.Tensor): c = b * b return self.silu(c) -with torch.device('cuda'): + +with torch.device("cuda"): in_features = 4096 out_features = 11008 model = Module(in_features, out_features) x = torch.randn(128, in_features, requires_grad=True) - jmodel_def = thunder.jit(model, ) - jmodel_auto = thunder.jit(model, autotune_type='runtime', executors=['nvfuser', 'torchcompile', 'cudnn', 'torch', 'python'], ) + jmodel_def = thunder.jit( + model, + ) + jmodel_auto = thunder.jit( + model, + autotune_type="runtime", + executors=["nvfuser", "torchcompile", "cudnn", "torch", "python"], + ) y = jmodel_def(x) y = jmodel_auto(x) iters = 100 - print('Results thunder benchmark:') + print("Results thunder benchmark:") fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], @@ -41,7 +49,7 @@ def forward(self, x: torch.Tensor): thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] + labels = ["def", "auto"] inputs = [x, x] - print('Results torch benchmark:') + print("Results torch benchmark:") torch_fw_bw_benchmark(callables, labels, inputs, 50) diff --git a/examples/dev/te.py b/examples/dev/te.py index 1c1d541ab1..6341506454 100644 --- a/examples/dev/te.py +++ b/examples/dev/te.py @@ -1,6 +1,12 @@ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark +from thunder.benchmarks.utils import ( + thunder_fw_bw_benchmark, + torch_fw_bw_benchmark, + torch_fw_bw_benchmark_nvsight, + torch_total_benchmark, +) + class Module(torch.nn.Module): def __init__(self, in_features, out_features) -> None: @@ -14,18 +20,22 @@ def __init__(self, in_features, out_features) -> None: def forward(self, x: torch.Tensor): return self.linear(x) -with torch.device('cuda'): + +with torch.device("cuda"): m = 1 in_features = 4096 * m out_features = 4096 * m model = Module(in_features, out_features) x = torch.randn(768, in_features, requires_grad=True) - jmodel_def = thunder.jit(model, executors=['transformer_engine'], use_cudagraphs=False) + jmodel_def = thunder.jit(model, executors=["transformer_engine"], use_cudagraphs=False) jmodel_auto = thunder.jit( model, autotune_type="runtime", - executors=["nvfuser", "transformer_engine", ], + executors=[ + "nvfuser", + "transformer_engine", + ], use_cudagraphs=False, ) @@ -43,12 +53,11 @@ def forward(self, x: torch.Tensor): ] fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] - print('Results thunder benchmark:') + print("Results thunder benchmark:") thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) callables = [jmodel_def, jmodel_auto] - labels = ['def', 'auto'] + labels = ["def", "auto"] inputs = [x, x] - print('\n\nResults torch benchmark:') + print("\n\nResults torch benchmark:") torch_total_benchmark(callables, labels, inputs, iters) - diff --git a/examples/dev/test_del.py b/examples/dev/test_del.py index 71aab23cb0..50654edbc3 100644 --- a/examples/dev/test_del.py +++ b/examples/dev/test_del.py @@ -3,8 +3,7 @@ iters = 1000 -with torch.device('cuda'): - +with torch.device("cuda"): tot_time = 0 for i in range(iters): s = time.time_ns() @@ -16,7 +15,7 @@ del b del c torch.cuda.synchronize() - tot_time += (time.time_ns() - s) + tot_time += time.time_ns() - s print(f"With del = {(tot_time / iters) / 1000000}") @@ -28,6 +27,6 @@ c = a + b + a + b c = c * c torch.cuda.synchronize() - tot_time += (time.time_ns() - s) + tot_time += time.time_ns() - s print(f"With no del = {(tot_time / iters) / 1000000}") diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index f58aa18dd9..8add55a3c9 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -12,6 +12,7 @@ from itertools import chain import torch + # Maybe we can use id(s) def sequence_hash(s: Sequence) -> str: def rec(s) -> str: @@ -20,18 +21,19 @@ def rec(s) -> str: if e is None: name += "None#" elif hasattr(e, "name"): - name += e.name + '#' + name += e.name + "#" elif isinstance(e, Sequence) and not isinstance(e, str): name += rec(e) elif isinstance(e, int): - name += 'int' + str(e) + '#' + name += "int" + str(e) + "#" else: raise AssertionError(f"Unsupported type = {type(e)}") - name += ']' + name += "]" return name return rec(s) + def can_executor_execute(ex: Executor, bsym: BoundSymbol) -> bool: try: return ex.can_execute(bsym) @@ -49,6 +51,7 @@ def get_first_available_operator_executor( return ex return Executor(name=empty_hash) + def flatten_sequence(sequence: Sequence) -> list: res = [] for e in sequence: @@ -59,6 +62,7 @@ def flatten_sequence(sequence: Sequence) -> list: res.append(e) return res + def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[Proxy]: def is_in_sequence(seq: Sequence[Any], t: Proxy): for e in seq: @@ -72,7 +76,7 @@ def unpack_output(out) -> Sequence[Proxy]: elif isinstance(out, Sequence): return flatten_sequence(out) else: - raise RuntimeError(f'Unpack operation not defined for {type(out)}') + raise RuntimeError(f"Unpack operation not defined for {type(out)}") ans: list[Proxy] = [] # Currently this is O(max(len(bsym.output)) * N^2) @@ -83,17 +87,14 @@ def unpack_output(out) -> Sequence[Proxy]: for e in unpacked_out: # None values are checked inside the unpack_output fn for b in trace_in.bound_symbols: - if ( - b.args is not None - and isinstance(b.args, Sequence) - and is_in_sequence(b.args, e) - ): + if b.args is not None and isinstance(b.args, Sequence) and is_in_sequence(b.args, e): f = True break if not f: ans.append(e) from thunder.backend_optimizer.optimizer import log, LogLevel - log(f'Returning not used proxies: {[p.name for p in ans]}', level=LogLevel.DEBUG) + + log(f"Returning not used proxies: {[p.name for p in ans]}", level=LogLevel.DEBUG) return ans @@ -269,6 +270,7 @@ def operation_in_trace(*, trace: TraceCtx, op: str) -> bool: return True return False + def is_te_used(trace: TraceCtx) -> bool: from thunder.executors.transformer_engineex import linear_bound_symbol_name_prefix from thunder.executors.transformer_engineex import te_functional_linear_backward_name @@ -281,9 +283,11 @@ def is_te_used(trace: TraceCtx) -> bool: return True return False + def is_backward_trace(trace: TraceCtx) -> bool: sig = trace.signature_with_no_ctx() - return sig.find('backward') >= 0 + return sig.find("backward") >= 0 + def benchmark_trace( trace: TraceCtx, @@ -294,7 +298,7 @@ def benchmark_trace( snapshot_name="", nvsight: bool = False, nvsight_fn_name: str = "", - **kwargs + **kwargs, ) -> tuple[float, float, Any]: from thunder.executors.passes import del_last_used import inspect @@ -325,6 +329,7 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f return float("inf"), float("inf"), None except Exception as e: import inspect + trc = inspect.getsource(fn) print(f"#Trace execution failed for nvsight (error: {e}):\n{trc}") raise e @@ -397,17 +402,19 @@ def build_static_args(sequence: Sequence, **kwargs) -> list: return transform_proxy_to_torch(sequence, level=0, **kwargs) def backward_trace_args_preprocess() -> list: - if 'fw_trace' not in kwargs: - raise RuntimeError('Set the associated forward trace in order to benchmark backward pass with sdpa executor') - fw_trace = kwargs.get('fw_trace', None) + if "fw_trace" not in kwargs: + raise RuntimeError( + "Set the associated forward trace in order to benchmark backward pass with sdpa executor" + ) + fw_trace = kwargs.get("fw_trace", None) if not isinstance(fw_trace, TraceCtx): - raise AssertionError(f'forward trace is not a TraceCtx. Received: {type(fw_trace)}') + raise AssertionError(f"forward trace is not a TraceCtx. Received: {type(fw_trace)}") # Run the fw trace and get the outputs fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) sig = fw_trace.signature_with_no_ctx() - is_fw_final_trace = sig.startswith('def augmented') + is_fw_final_trace = sig.startswith("def augmented") # Filter the C0 tuple # These location might change if the implementation of the automatic @@ -423,7 +430,7 @@ def backward_trace_args_preprocess() -> list: # Now, we expected that if the fw trace is a final trace also the bw trace is a final one. And vice versa if is_fw_final_trace: # Swap saved_for_backward_traces - saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) + saved_for_bw = saved_for_bw_C0, fw_output[1][1] # Saved for backward tuple unpacks in (C0, _) # Subsitute the static inputs for saved_for_backward with the runtime ones input_args.pop(0) input_args.insert(0, saved_for_bw) @@ -453,7 +460,9 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa See how the backward trace need t4 as argument recoveered from the static args """ updated_input_args = [t for t in saved_for_bw_C0] - updated_input_args.extend(input_args[len(updated_input_args):]) # Should be only one variable but leave this dyanamic + updated_input_args.extend( + input_args[len(updated_input_args) :] + ) # Should be only one variable but leave this dyanamic input_args = updated_input_args return input_args @@ -471,7 +480,7 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa # that afterwards at least one TE symbol will be included # NOTE: compile data could be None if this benchmark util is used outside the compilation process. If this is the case we are benchmarking # a whole trace (in theory) and is_te_used API will return the needed result. - te_used = (cd.compile_options.get('te_used', False) if cd else False) or is_te_used(trace) + te_used = (cd.compile_options.get("te_used", False) if cd else False) or is_te_used(trace) if te_used: cached_te_fp8_autocast_value = trace._include_te_fp8_autocast trace._include_te_fp8_autocast = True @@ -502,6 +511,7 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) except Exception as e: import traceback + ex_str = traceback.format_exc() print(ex_str) # https://github.com/Lightning-AI/lightning-thunder/issues/664 @@ -574,35 +584,42 @@ def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *ar return out + def print_trace_args(trace: TraceCtx): print_nested_sequence(trace.args) + # Display nest sequence arguments def print_nested_sequence(args, show_dicts=False): - def is_tensor(t): return isinstance(t, torch.Tensor) or isinstance(t, TensorProxy) if not isinstance(args, Sequence): return - print('###################################### Sequence start') + print("###################################### Sequence start") + def _print(args, level): - tabs = '\t' * level - print(f'Level {level} start') + tabs = "\t" * level + print(f"Level {level} start") for arg in args: if isinstance(arg, Sequence): - _print(arg, level+1) + _print(arg, level + 1) else: tensor_shape = arg.shape if is_tensor(arg) else None dtype = arg.dtype if is_tensor(arg) else None name = arg.name if isinstance(arg, TensorProxy) else "" - print(f'{tabs}{name + ": " if name else ""}{type(arg)}{arg if isinstance(arg, dict) and show_dicts else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}') - print(f'Level {level} end') + print( + f'{tabs}{name + ": " if name else ""}{type(arg)}{arg if isinstance(arg, dict) and show_dicts else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}' + ) + print(f"Level {level} end") + _print(args, 0) - print('###################################### Debug args\n') + print("###################################### Debug args\n") + def update_compile_options_executor_list_after_fw_bw_split() -> None: from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options + cd = get_compile_data() assert cd @@ -617,12 +634,13 @@ def update_compile_options_executor_list_after_fw_bw_split() -> None: executors_list.remove(ex) # Putting at the front even though order does not matter - for ex in cd.compile_options['executors_placed_by_fw_bw_split']: + for ex in cd.compile_options["executors_placed_by_fw_bw_split"]: executors_list.insert(0, ex) # Assign new compilation executors options cd.executors_list = executors_list + def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype, is_boolean_dtype @@ -634,16 +652,20 @@ def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: requires_grad = arg.requires_grad torch_dtype = to_torch_dtype(dtype) if torch_dtype is None: - raise AssertionError(f'Unrecognized thunder dtype: {dtype}') + raise AssertionError(f"Unrecognized thunder dtype: {dtype}") if is_float_dtype(dtype): # Use TE Float8 if TE is enabled, it has float32 torch dtype - te_used = kwargs.get('te_used', False) + te_used = kwargs.get("te_used", False) if te_used: tensor: torch.Tensor = torch.randn( - shape, dtype=torch_dtype if dtype.bytes > 1 else torch.float32, device=device.device_str(), requires_grad=requires_grad + shape, + dtype=torch_dtype if dtype.bytes > 1 else torch.float32, + device=device.device_str(), + requires_grad=requires_grad, ) if dtype.bytes == 1: import transformer_engine.pytorch as te + tensor = te.float8_tensor.Float8Tensor.to_float8(tensor) # Support standard float tensors else: @@ -664,8 +686,10 @@ def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: return tensor + def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | list: from thunder.executors.transformer_engineex import Context as C + res = [] for e in sequence: if type(e) is tuple: @@ -693,16 +717,19 @@ def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | l # If the static input generator will be capable to generate only the cotangents then branch will not be used anymore # # Currently an option to fill a custom maybe real context is left. - elif hasattr(e, 'name') and isinstance(e, AnyProxy) and e.name.startswith('ctx_te'): - required_context = kwargs.get('cached_fw_te_ctx_out', None) + elif hasattr(e, "name") and isinstance(e, AnyProxy) and e.name.startswith("ctx_te"): + required_context = kwargs.get("cached_fw_te_ctx_out", None) res.append(required_context if required_context is not None else C()) elif e is None: res.append(None) else: - raise AssertionError(f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}') + raise AssertionError( + f'Input arg type not recognized: {type(e)} with name: {e.name if hasattr(e, "name") else "unknown"} with value: {e}' + ) # Outer container must be a list return tuple(res) if level > 0 else res + def reorder_executors_list(executors: Sequence): from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options from thunder.executors.torch_compile import torch_compile_ex @@ -733,7 +760,7 @@ def reorder_executors_list(executors: Sequence): for ex in reordered: if are_inputs_names and (ex == nvfuser_ex.name or ex == torch_compile_ex.name): found = True - elif (ex == nvfuser_ex or ex == torch_compile_ex): + elif ex == nvfuser_ex or ex == torch_compile_ex: found = True if not found: reordered.insert(0, nvfuser_ex.name if are_inputs_names else nvfuser_ex) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 04b66e3d2f..bcaa7007cc 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -5,27 +5,29 @@ warm_up_iters = 50 -class SplitFwBwBenchmarkUtils(): + +class SplitFwBwBenchmarkUtils: def __init__( - self, *, cost: float = float("inf"), fw_fn: Callable | None = None, bw_fn: Callable | None = None, executor = None + self, *, cost: float = float("inf"), fw_fn: Callable | None = None, bw_fn: Callable | None = None, executor=None ) -> None: self.cost: float = cost self.fw_fn: Callable | None = fw_fn self.bw_fn: Callable | None = bw_fn self.executor = executor -class AutotunerTorchAutogradBenchmarkUtils(): + +class AutotunerTorchAutogradBenchmarkUtils: def __init__( self, - cost: float = float('inf'), + cost: float = float("inf"), fw_trace: TraceCtx | None = None, bw_trace: TraceCtx | None = None, fw_traces: Sequence[TraceCtx] = [], bw_traces: Sequence[TraceCtx] = [], primal_trace: TraceCtx | None = None, - executor = None, - selected_executors: Sequence = [] - ) -> None: + executor=None, + selected_executors: Sequence = [], + ) -> None: self.cost: float = cost self.fw_trace = fw_trace self.bw_trace = bw_trace @@ -37,7 +39,6 @@ def __init__( def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int) -> None: - for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): @@ -50,18 +51,18 @@ def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iter for i in range(iters): torch.cuda.empty_cache() torch.cuda.nvtx.range_push(f"torch training {label}, iter {i}") - torch.cuda.nvtx.range_push('forward') + torch.cuda.nvtx.range_push("forward") y = m(input) torch.cuda.nvtx.range_pop() loss = y.sum() - torch.cuda.nvtx.range_push('backward') + torch.cuda.nvtx.range_push("backward") loss.backward() torch.cuda.nvtx.range_pop() torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() -def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: +def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): @@ -82,16 +83,13 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) y = m(input) end_events[i].record(stream) - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) - ) + max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) torch.cuda.synchronize() tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] tot_time = sum(tot) / iters - print(f'{label} tot fw time: {tot_time} ms') - print(f'{label} max fw allocated memory: {max_allocated_bytes / (2**30)} GB') + print(f"{label} tot fw time: {tot_time} ms") + print(f"{label} max fw allocated memory: {max_allocated_bytes / (2**30)} GB") start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -109,19 +107,16 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) loss.backward() end_events[i].record(stream) - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) - ) + max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) torch.cuda.synchronize() tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] tot_time = sum(tot) / iters - print(f'{label} tot bw time: {tot_time} ms') - print(f'{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB') + print(f"{label} tot bw time: {tot_time} ms") + print(f"{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB") -def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: +def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): @@ -144,28 +139,44 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) loss.backward() end_events[i].record(stream) - max_allocated_bytes = max( - max_allocated_bytes, torch.cuda.max_memory_allocated( - torch.cuda.current_device()) - ) + max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) torch.cuda.synchronize() tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] tot_time = sum(tot) / iters - print(f'{label} tot time: {tot_time} ms') - print(f'{label} max allocated memory: {max_allocated_bytes / (2**30)} GB') + print(f"{label} tot time: {tot_time} ms") + print(f"{label} max allocated memory: {max_allocated_bytes / (2**30)} GB") -def thunder_fw_bw_benchmark(fw_traces: list, bw_traces: list, fw_labels: list, bw_labels: list, iters: int, nvsight: bool = False) -> None: - assert(len(fw_traces) == len(bw_traces) == len(fw_labels) == len(bw_labels)) +def thunder_fw_bw_benchmark( + fw_traces: list, bw_traces: list, fw_labels: list, bw_labels: list, iters: int, nvsight: bool = False +) -> None: + assert len(fw_traces) == len(bw_traces) == len(fw_labels) == len(bw_labels) for trc, label in zip(fw_traces, fw_labels): - c, m, _ = benchmark_trace(trc, apply_del_last_used=False, snapshot=True, snapshot_name=label, iters=iters, nvsight=nvsight, nvsight_fn_name=label) + c, m, _ = benchmark_trace( + trc, + apply_del_last_used=False, + snapshot=True, + snapshot_name=label, + iters=iters, + nvsight=nvsight, + nvsight_fn_name=label, + ) if not nvsight: - print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') + print(f"Executing {label} trace:\n{c} ms, {m / (2**30)} GB") i = 0 for trc, label in zip(bw_traces, bw_labels): - c, m, _ = benchmark_trace(trc, apply_del_last_used=False, snapshot=True, snapshot_name=label, iters=iters, nvsight=nvsight, nvsight_fn_name=label, fw_trace=fw_traces[i]) + c, m, _ = benchmark_trace( + trc, + apply_del_last_used=False, + snapshot=True, + snapshot_name=label, + iters=iters, + nvsight=nvsight, + nvsight_fn_name=label, + fw_trace=fw_traces[i], + ) if not nvsight: - print(f'Executing {label} trace:\n{c} ms, {m / (2**30)} GB') + print(f"Executing {label} trace:\n{c} ms, {m / (2**30)} GB") i += 1 From 5ce0002a72368b697f9461b6da8551d59b71c436 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 12:10:51 +0300 Subject: [PATCH 106/171] Fixed tensor device --- thunder/tests/test_autotuner.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index db4a025a9f..808d258cf8 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -221,25 +221,25 @@ def forward(self, q, k, v): [ ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), ["cudnn", "sdpa", "fa3"], 1, ), ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), ["cudnn", "sdpa"], 1, ), ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), [ "cudnn", ], @@ -247,9 +247,9 @@ def forward(self, q, k, v): ), ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), [], 0, ), From 388d7d0fb1c9706a9cd3cf42d456b4b10314b5fe Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 12:55:17 +0300 Subject: [PATCH 107/171] Added cuda guard --- thunder/tests/test_autotuner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 808d258cf8..4295f7e222 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -387,5 +387,7 @@ def forward(self, x: torch.Tensor): (["transformer_engine", "nvfuser", "sdpa"], ["transformer_engine", "sdpa", "nvfuser"]), ], ) +# We might not have nvfuser in non cuda envs +@requiresCUDA def test_reorder_executors_list(executors, expected): assert aut_utils.reorder_executors_list(executors) == expected From 19953cd3e47e675a5027653427777d6e55b6191f Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 13:07:55 +0300 Subject: [PATCH 108/171] Formatter --- thunder/__init__.py | 15 ++-- thunder/core/codeutils.py | 5 +- thunder/core/transform_common.py | 12 +++- thunder/core/transforms.py | 3 +- thunder/core/vjp_utils.py | 39 +++++++---- thunder/executors/nvmathex.py | 13 ++-- thunder/executors/passes.py | 14 ++-- thunder/executors/sdpaex.py | 1 + thunder/executors/torch_autograd.py | 11 +-- thunder/executors/transformer_engineex.py | 2 + thunder/tests/test_autotuner.py | 84 +++++++++++++++++++---- 11 files changed, 143 insertions(+), 56 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index ecf729ac2c..26ed89372e 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -326,14 +326,14 @@ def jit( if transforms is None: transforms = [] - required_autotune = compile_options.get('autotune_type', None) + required_autotune = compile_options.get("autotune_type", None) if required_autotune is not None: - if required_autotune not in ['runtime', 'memory']: - raise AssertionError(f'Not supported optimization: {required_autotune}') + if required_autotune not in ["runtime", "memory"]: + raise AssertionError(f"Not supported optimization: {required_autotune}") compile_options |= { - "autotune_type": OptimizerType.RUNTIME if required_autotune == 'runtime' else OptimizerType.MEMORY, - "executors_placed_by_fw_bw_split": set() + "autotune_type": OptimizerType.RUNTIME if required_autotune == "runtime" else OptimizerType.MEMORY, + "executors_placed_by_fw_bw_split": set(), } # Default the executors list to all_executors if no options are given @@ -341,9 +341,10 @@ def jit( if not executors: executors = get_all_executors() # Remove python and cudagraph - executors = [ex for ex in executors if ex.name != 'python' and ex.name != 'cudagraphex'] + executors = [ex for ex in executors if ex.name != "python" and ex.name != "cudagraphex"] from thunder.backend_optimizer.utils import reorder_executors_list + executors = reorder_executors_list(executors) # Resolve names of executors @@ -792,7 +793,6 @@ def fn_(*args, **kwargs) -> Any: cs.last_trace_host_execution_start = time.perf_counter_ns() if cache_entry.vanilla_tensor_args: - if alias_tensor_indices_str := _alias_tensor_of_args_kwargs(*inps): alias_tensor_indices = alias_tensor_indices_str alias_tensor_indices = {int(i) for i in alias_tensor_indices_str.split(",")} @@ -1144,4 +1144,3 @@ def _fn(*args, **kwargs): return original_result, original_trace return _fn - diff --git a/thunder/core/codeutils.py b/thunder/core/codeutils.py index d1c6d30406..a4ec6c2440 100644 --- a/thunder/core/codeutils.py +++ b/thunder/core/codeutils.py @@ -460,6 +460,7 @@ class NamedBindings: si.unwrapped_fn = unwrapped return si + def get_siginfo_name(trace) -> str: try: name = "" @@ -467,8 +468,8 @@ def get_siginfo_name(trace) -> str: siginfo: SigInfo = get_siginfo(trace.fn, trace.args, trace.kwargs) name = siginfo.name else: - name = 'unknown' + name = "unknown" return name except Exception as e: - raise AssertionError(f'Is input trace an instance of TraceCtx?\n{e}') + raise AssertionError(f"Is input trace an instance of TraceCtx?\n{e}") diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index 88da7d9907..ccad0c6280 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -100,13 +100,15 @@ def dce(trace: Trace, needed_proxies: None | set[Variable] = None) -> Trace: start_time_ns = time.perf_counter_ns() cd = get_compile_data() - disabled = not(not cd or (cd and not cd.compile_options.get('disable_dce', None))) + disabled = not (not cd or (cd and not cd.compile_options.get("disable_dce", None))) if not disabled: producer_map: ProxyDict = producers(trace) flat_trace_outputs, _ = tree_flatten(trace.output) if needed_proxies is None: - needed_proxies: set[Variable] = set(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) + needed_proxies: set[Variable] = set( + tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy)) + ) else: needed_proxies.update(tuple(variableify(x) for x in flat_trace_outputs if isinstance(x, Proxy))) dced = [] @@ -162,7 +164,11 @@ def _helper(x): end_time_ns = time.perf_counter_ns() elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - dcetrace.set_provenance(TraceProvenance(f"Dead Code Elimination{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)")) + dcetrace.set_provenance( + TraceProvenance( + f"Dead Code Elimination{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)" + ) + ) return dcetrace diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index bb38cbf947..87f0a7d9e9 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -394,6 +394,7 @@ def visitor_transform_paired(trace_from: Trace, visit: Callable, zipped: zip, *, finally: reset_tracectx(tracectx_tok) + # Creates a new trace from "trace_from" by calling "visit" on its bound symbols ("bsyms"). # visit(bsym: BoundSymbolInterface) -> VISIT_TYPE should call operations # as if executing a program, and those operations will be recorded into the @@ -1488,7 +1489,6 @@ def grad( cfn, ) -> Callable: def grad(func): - @wraps(func) def grad_func(*args, **kwargs): _, grads = value_and_grad(func)(*args, **kwargs) @@ -3739,7 +3739,6 @@ def ones_like(x): else: return None - forward_trace = construct_trace()(augmented_forward_fn, *trace.args, **trace.kwargs) # We set forward trace to construct proxies because we need these proxies to # have different names than the ones in the forward trace. diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 055dc5103d..a956a54b65 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -66,7 +66,9 @@ def _make_aug_forward_and_backward(*, return_traces=False, update_cache=True) -> if cached_result is not None and not getattr(joint_forward_backward, "_disable_caching", False): return cached_result - joint_trace = thunder.trace(inline_trace=False, use_dce=False)(joint_forward_backward, *bsym.args, **bsym.kwargs) + joint_trace = thunder.trace(inline_trace=False, use_dce=False)( + joint_forward_backward, *bsym.args, **bsym.kwargs + ) consumers = utils.consumers(joint_trace) def find_backward_input(forward_output): @@ -115,7 +117,9 @@ def find_backward_output(forward_input): forward_input_proxies = tree_flatten((joint_trace.args, joint_trace.kwargs))[0] forward_input_proxies = [arg for arg in forward_input_proxies if isinstance(arg, Proxy)] - forward_bsyms = utils.find_producer_symbols(joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies) + forward_bsyms = utils.find_producer_symbols( + joint_trace, tree_flatten(joint_trace.output)[0], forward_input_proxies + ) backward_bsyms = [bsym for bsym in backward_bsyms if bsym not in forward_bsyms] # Find required info from forward trace for backward trace @@ -125,7 +129,9 @@ def find_backward_output(forward_input): for arg in backward_bsym.flat_args: if not isinstance(arg, Proxy): continue - if arg not in backward_producers and variableify(arg) not in map(variableify, tree_flatten(bw_inputs)[0]): + if arg not in backward_producers and variableify(arg) not in map( + variableify, tree_flatten(bw_inputs)[0] + ): saved_for_backward.append(arg) saved_for_backward = list({variableify(arg): arg for arg in saved_for_backward}.values()) @@ -202,7 +208,7 @@ def bw_fn(*args, **kwargs): cd = get_compile_data() # No autotuning - if not cd or not cd.compile_options.get('autotune_type', None): + if not cd or not cd.compile_options.get("autotune_type", None): return _make_aug_forward_and_backward() # This search will be performed on the requested executors list @@ -218,7 +224,7 @@ def bw_fn(*args, **kwargs): from thunder.backend_optimizer.utils import benchmark_trace # In order define this unique trace region we need an unique id - key = (bsym.sym, Executor(f'{id(bsym)}-autotuned'), subkey := _make_cache_key(bsym.args, bsym.kwargs)) + key = (bsym.sym, Executor(f"{id(bsym)}-autotuned"), subkey := _make_cache_key(bsym.args, bsym.kwargs)) # We do check the cache here as the key in the inner fn does not know about this special id cached_result = _cache.get(key, None) if subkey is not None else None # NOTE: cache is always enabled here @@ -237,15 +243,17 @@ def bw_fn(*args, **kwargs): requested_executors_list_for_bsym = [ex for ex in cached_executors_list if ex in backends] from thunder.benchmarks.utils import SplitFwBwBenchmarkUtils from thunder.backend_optimizer.optimizer import OptimizerType + best = SplitFwBwBenchmarkUtils() # Restrict the search space backends = list(requested_executors_list_for_bsym) from thunder.backend_optimizer.optimizer import log, LogLevel - log(f'Search space for {bsym.sym.name}: {backends}', level=LogLevel.INFO) + + log(f"Search space for {bsym.sym.name}: {backends}", level=LogLevel.INFO) for b in backends: - log(f'Benchmarking executor {b.name} for {bsym.sym.name}', level=LogLevel.INFO) + log(f"Benchmarking executor {b.name} for {bsym.sym.name}", level=LogLevel.INFO) # Let downstream fn to pick up this requested_executors_list_for_bsym.remove(b) requested_executors_list_for_bsym.insert(0, b) @@ -255,18 +263,25 @@ def bw_fn(*args, **kwargs): # TODO: make benchmark info taken from an autotuner config fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=100, apply_del_last_used=False) bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=100, apply_del_last_used=False, fw_trace=fw_trace) - cost = fw_time + bw_time if cd.compile_options['autotune_type'] == OptimizerType.RUNTIME else fw_mem + bw_mem + cost = ( + fw_time + bw_time if cd.compile_options["autotune_type"] == OptimizerType.RUNTIME else fw_mem + bw_mem + ) if cost < best.cost: - best = SplitFwBwBenchmarkUtils(cost = cost, fw_fn = fw_fn, bw_fn = bw_fn, executor = b) + best = SplitFwBwBenchmarkUtils(cost=cost, fw_fn=fw_fn, bw_fn=bw_fn, executor=b) - assert best.cost != float('inf') + assert best.cost != float("inf") from thunder.backend_optimizer.optimizer import log - log(f'Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}', level=LogLevel.INFO) + + log( + f"Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}", + level=LogLevel.INFO, + ) # Update the compile options cd.compile_options["executors_placed_by_fw_bw_split"].add(best.executor) from thunder.executors.transformer_engineex import transformer_engine_ex - cd.compile_options |= {'te_used': True if best.executor == transformer_engine_ex else False} + + cd.compile_options |= {"te_used": True if best.executor == transformer_engine_ex else False} # Restore executor list for downstream optimizations cd.executors_list = cached_executors_list # The executors used in this pass will be updated after the termination of the forward_and_backward_from_trace call diff --git a/thunder/executors/nvmathex.py b/thunder/executors/nvmathex.py index cea6c4e286..0661822a78 100644 --- a/thunder/executors/nvmathex.py +++ b/thunder/executors/nvmathex.py @@ -5,12 +5,14 @@ import thunder.torch as ltorch import torch -nvmath_ex = thunder.extend.OperatorExecutor('nvmath', version='0.1.0') +nvmath_ex = thunder.extend.OperatorExecutor("nvmath", version="0.1.0") thunder.extend.register_executor(nvmath_ex) + def _nvmath_linalg_advanced_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: return nvmath.linalg.advanced.matmul(a, b) + def _nvmath_linalg_advanced_matmul_checker(a: TensorProxy, b: TensorProxy) -> bool: if len(a.shape) < 2 or len(b.shape) < 2: return False @@ -23,19 +25,16 @@ def _nvmath_linalg_advanced_matmul_checker(a: TensorProxy, b: TensorProxy) -> bo # Handle distribuited return True + nvmath_linalg_advanced_matmul = nvmath_ex.register_operator( "nvmath_linalg_advanced_matmul", like=ltorch.matmul, fn=_nvmath_linalg_advanced_matmul_impl, ) nvmath_ex.register_implementation( - ltorch.matmul, - nvmath_linalg_advanced_matmul, - checker=_nvmath_linalg_advanced_matmul_checker + ltorch.matmul, nvmath_linalg_advanced_matmul, checker=_nvmath_linalg_advanced_matmul_checker ) nvmath_ex.register_implementation( - PrimIDs.MATMUL, - nvmath_linalg_advanced_matmul, - checker=_nvmath_linalg_advanced_matmul_checker + PrimIDs.MATMUL, nvmath_linalg_advanced_matmul, checker=_nvmath_linalg_advanced_matmul_checker ) diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index dd70fada74..a6fb17526a 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -71,7 +71,6 @@ def visit_helper_(bsym: BoundSymbol) -> None | bool: ex: Executor for ex in executors_list: - # TODO Consider allowing operator executors to claim portions of operations # TODO Should FusionExecutors be allowed to claim bsym with bsym.sym.executor? @@ -139,6 +138,7 @@ def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: ) return extrace + # Autotuned transform_for_execution version def autotune_transform_for_execution( *, optimizer_context: BackendOptimizer, trace: TraceCtx, trace_type: TraceType @@ -158,7 +158,7 @@ def autotune_transform_for_execution( # Attach new trace and set the debug file name optimizer_context.attach_trace(trace=trace, trace_type=trace_type) - optimizer_context.log_file_name = f'autotune_transform_for_execution_{sig_name}.log' + optimizer_context.log_file_name = f"autotune_transform_for_execution_{sig_name}.log" # Forward traces are cached inside the context optimizer_context.optimize() match trace_type: @@ -188,6 +188,7 @@ def autotune_transform_for_execution( ) return fw_extrace, bw_extrace + def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) -> TraceCtx: import torch @@ -207,7 +208,6 @@ def transform_for_execution(trace: TraceCtx, executors_list: Sequence[Executor]) extrace = _transform_for_operator_executor_execution(trace, executors_list) extrace = dce(extrace) - # # Step 2 Fusion executors can transform the trace # @@ -356,7 +356,7 @@ def del_last_used(trace: TraceCtx, *, clear_mutable_collections=False) -> TraceC # If dce is disabled, we have to disable this pass also cd = get_compile_data() - disabled = not(not cd or (cd and not cd.compile_options.get('disable_dce', None))) + disabled = not (not cd or (cd and not cd.compile_options.get("disable_dce", None))) start_time_ns = time.perf_counter_ns() @@ -375,6 +375,10 @@ def del_last_used(trace: TraceCtx, *, clear_mutable_collections=False) -> TraceC elapsed_time_ns = end_time_ns - start_time_ns elapsed_time_millis = elapsed_time_ns // 1000000 - del_trace.set_provenance(TraceProvenance(f"Delete Last Used{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)")) + del_trace.set_provenance( + TraceProvenance( + f"Delete Last Used{' Skipped Per Compile Options' if disabled else ''} (took {elapsed_time_millis} milliseconds)" + ) + ) return del_trace diff --git a/thunder/executors/sdpaex.py b/thunder/executors/sdpaex.py index cb569f3e28..31347a5a60 100644 --- a/thunder/executors/sdpaex.py +++ b/thunder/executors/sdpaex.py @@ -389,6 +389,7 @@ def _scaled_dot_product_efficient_attention_backward_impl( scale=scale, ) + sdpaex_scaled_dot_product_efficient_attention_backward_name = "sdpaex_scaled_dot_product_efficient_attention_backward" sdpea_bwd = sdpa_ex.register_operator( "sdpaex_scaled_dot_product_efficient_attention_backward", diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index 15be9079a5..c18a847b0a 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -106,7 +106,8 @@ def backward(ctx, *args): del grads return (None, None, None, None, None, *([None] * n_grads)) -def update_bw_from_forward_optimization(*, fw: TraceCtx, bw:TraceCtx) -> TraceCtx: + +def update_bw_from_forward_optimization(*, fw: TraceCtx, bw: TraceCtx) -> TraceCtx: # Some of the optimization passes change proxies in the trace and # any change in the forward trace must be reflected in the backward # trace. @@ -117,8 +118,7 @@ def update_bw_from_forward_optimization(*, fw: TraceCtx, bw:TraceCtx) -> TraceCt for x, y in zip(original_bw_saved_tensors_for_backward, new_fw_saved_tensors_for_backward) if variableify(x) != variableify(y) } - new_bsyms = replace_redundant_inputs( - swap_map, bw.bound_symbols) + new_bsyms = replace_redundant_inputs(swap_map, bw.bound_symbols) # replace_redundant_inputs doesn't replace the output of # UNPACK_SEQUENCE so we do it manually. Here we have certain # assumptions about the structure of the backward trace. @@ -136,6 +136,7 @@ def update_bw_from_forward_optimization(*, fw: TraceCtx, bw:TraceCtx) -> TraceCt return bw + def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stats, /, *flat_args): from thunder.backend_optimizer.optimizer import TraceType, BackendOptimizer from thunder.backend_optimizer.utils import update_compile_options_executor_list_after_fw_bw_split @@ -154,7 +155,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat if not any(requires_grad_mask): raise RuntimeError("PyTorch's Autograd interface requires at least one tensor input with requires_grad=True") - autotune_type = compile_data.compile_options.get('autotune_type', None) + autotune_type = compile_data.compile_options.get("autotune_type", None) primal_trace = computation_trc primal_trace = sort_data_parallel_syncs(primal_trace) @@ -213,7 +214,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, produce_log=True, optimizer_type=autotune_type, - compile_data=compile_data + compile_data=compile_data, ) ) diff --git a/thunder/executors/transformer_engineex.py b/thunder/executors/transformer_engineex.py index e24df4be61..89b38768f7 100644 --- a/thunder/executors/transformer_engineex.py +++ b/thunder/executors/transformer_engineex.py @@ -410,6 +410,7 @@ def _te_functional_linear_backward_meta( TensorProxy(like=g, shape=b_shape) if b_shape else None, ) + te_functional_linear_backward_name: str = "te_functional_linear_backward" te_functional_linear_backward = transformer_engine_ex.register_operator( @@ -427,6 +428,7 @@ def _te_functional_linear_backward_meta( linear_bound_symbol_name_prefix: str = "te_linear" + # Creates a new stateful operator for each invocation of `linear`. def _create_fp8_linear_bound_symbol( a: TensorProxy, w: TensorProxy, b: TensorProxy, is_grad_enabled=False diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 4295f7e222..51513f725a 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -221,25 +221,70 @@ def forward(self, q, k, v): [ ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), ["cudnn", "sdpa", "fa3"], 1, ), ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), ["cudnn", "sdpa"], 1, ), ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), [ "cudnn", ], @@ -247,9 +292,24 @@ def forward(self, q, k, v): ), ( Sdpa(), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), - torch.randn([10, 128, 4, 32], dtype=torch.float16, device="cuda" if torch.cuda.is_available() else "cpu", requires_grad=True), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), + torch.randn( + [10, 128, 4, 32], + dtype=torch.float16, + device="cuda" if torch.cuda.is_available() else "cpu", + requires_grad=True, + ), [], 0, ), From 04124beca5daa05b549ba1cb81275cac67c1d4e5 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 15:57:22 +0300 Subject: [PATCH 109/171] Changed file name --- examples/dev/nvmath.py | 182 --------------------------------- examples/dev/nvmath_example.py | 0 2 files changed, 182 deletions(-) delete mode 100644 examples/dev/nvmath.py create mode 100644 examples/dev/nvmath_example.py diff --git a/examples/dev/nvmath.py b/examples/dev/nvmath.py deleted file mode 100644 index db6d565038..0000000000 --- a/examples/dev/nvmath.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED. -# -# SPDX-License-Identifier: Apache-2.0 - -""" -This example demonstrates basic matrix multiplication of torch tensors. - -nvmath-python supports multiple frameworks. The result of each operation is a tensor of the same -framework that was used to pass the inputs. It is also located on the same device as the inputs. -""" - -import torch -import nvmath -import thunder - -torch.set_default_dtype(torch.bfloat16) -torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul -torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - -a = torch.randn(128256, 8192, device="cuda") -b = torch.randn(8192, 4096, device="cuda") - -iters = 20 -start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -stream = torch.cuda.current_stream() -for i in range(iters): - result_a = torch.matmul(a, b) -torch.cuda.default_stream().synchronize() -for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - result_a = torch.matmul(a, b) - end_events[i].record(stream) -torch.cuda.default_stream().synchronize() -tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] -tot_time = sum(tot) / iters -print(f"torch tot time: {tot_time} ms") - -start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] -stream = torch.cuda.current_stream() -torch.cuda.default_stream().synchronize() -for i in range(iters): - result_b = nvmath.linalg.advanced.matmul(a, b) -torch.cuda.default_stream().synchronize() -for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - result_b = nvmath.linalg.advanced.matmul(a, b) - end_events[i].record(stream) -torch.cuda.default_stream().synchronize() -tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] -tot_time = sum(tot) / iters -print(f"nvmath tot time: {tot_time} ms") - -with nvmath.linalg.advanced.Matmul(a, b) as mm: - # Plan. - mm.plan() - - # Inspect the algorithms proposed. - print( - f"Planning returned {len(mm.algorithms)} algorithms. The capabilities of the best one are:", - ) - best = mm.algorithms[0] - print(best.capabilities) - - # Modify the tiling configuration of the algorithm. Note that the valid tile configuration depends on - # the hardware, and not all combinations of the configuration are supported, so we leave it as an exercise. - print(f"Tiling {best.tile}") - # Execute the multiplication. - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - stream = torch.cuda.current_stream() - torch.cuda.default_stream().synchronize() - for i in range(iters): - result_c = mm.execute() - torch.cuda.default_stream().synchronize() - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - result_c = mm.execute() - end_events[i].record(stream) - torch.cuda.default_stream().synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print(f"nvmath tuned tot time: {tot_time} ms") - -with nvmath.linalg.advanced.Matmul(a, b) as mm: - # Plan. - mm.plan() - - # Autotune - mm.autotune(iterations=5) - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - stream = torch.cuda.current_stream() - torch.cuda.default_stream().synchronize() - for i in range(iters): - result_d = mm.execute() - torch.cuda.default_stream().synchronize() - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - start_events[i].record(stream) - result_d = mm.execute() - end_events[i].record(stream) - torch.cuda.default_stream().synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print(f"nvmath tuned tot time: {tot_time} ms") - - -print("deviation:", (result_a - result_b).abs().max().item()) -print("deviation:", (result_a - result_c).abs().max().item()) -print("deviation:", (result_a - result_d).abs().max().item()) - -# from thunder.benchmarks.utils import thunder_fw_bw_benchmark -# from thunder.executors.nvmathex import nvmath_ex - -# class Module(torch.nn.Module): -# def __init__(self, *args, **kwargs) -> None: -# super().__init__(*args, **kwargs) - -# def forward(self, a, b): -# return a @ b - -# with torch.device('cuda'): - -# model = Module() -# jmodel_def = thunder.jit(model) -# jmodel_auto = thunder.jit(model, autotune_type="runtime", executors = [nvmath_ex], use_cudagraphs=False) -# a = torch.randn(128256, 128, requires_grad=True) -# b = torch.randn(128, 4096, requires_grad=True) - -# print('deviation def:', (jmodel_def(a, b) - model(a, b)).abs().max().item()) -# print('deviation auto:', (jmodel_auto(a, b) - model(a, b)).abs().max().item()) - -# from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark - -# print('Results with thunder benchmark:') -# fw_traces = [ -# thunder.last_traces(jmodel_def)[-1], -# thunder.last_traces(jmodel_auto)[-1], -# ] -# bw_traces = [ -# thunder.last_backward_traces(jmodel_def)[-1], -# thunder.last_backward_traces(jmodel_auto)[-1], -# ] -# fw_labels = ["fw_def", "fw_auto"] -# bw_labels = ["bw_def", "bw_auto"] -# thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10) - -# for t in fw_traces: -# print(f'{t}\n################################') -# for t in bw_traces: -# print(f'{t}\n################################') - -# for _ in range(iters): -# result_a = model(a.clone().detach(), b.clone().detach()) -# torch.cuda.default_stream().synchronize() -# s = time.time_ns() -# for i in range(iters): -# result_a = model(a.clone().detach(), b.clone().detach()) -# torch.cuda.default_stream().synchronize() -# e = time.time_ns() -# print('time torch', (e-s)/1000000, 'ms') - -# for _ in range(iters): -# result_b = jmodel(a.clone().detach(), b.clone().detach()) -# torch.cuda.default_stream().synchronize() -# s = time.time_ns() -# for i in range(iters): -# result_b = model(a.clone().detach(), b.clone().detach()) -# torch.cuda.default_stream().synchronize() -# e = time.time_ns() -# print('time nvmath', (e-s)/1000000, 'ms') diff --git a/examples/dev/nvmath_example.py b/examples/dev/nvmath_example.py new file mode 100644 index 0000000000..e69de29bb2 From 44a89392681219013263aef1b850f31cc0d8506d Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 16:05:17 +0300 Subject: [PATCH 110/171] Restored old value --- thunder/core/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py index 87f0a7d9e9..fed2321458 100644 --- a/thunder/core/transforms.py +++ b/thunder/core/transforms.py @@ -1564,7 +1564,7 @@ def read(x: Variable): else: return x - def write(v: Variable, val: Any, allow_duplicates=True) -> None: + def write(v: Variable, val: Any, allow_duplicates=False) -> None: if not isinstance(v, Variable): return # Duplicates are allowed and overwritten From 2bc7546a4fd2915b476a3bbc28a54ada10797707 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 27 Aug 2024 23:56:14 +0300 Subject: [PATCH 111/171] Restored flag --- examples/dev/litGPT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 839b02b118..662b2ffdf9 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -100,8 +100,8 @@ def __init__( fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] print(f"Results thunder benchmark ({iters} iters):") - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=True) - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) + thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=False) + # thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) print(test.model_name) print(f"\n\nResults torch fw bw benchmark ({iters} iters):") From 0fbcbb2bd599783824b2df949cac2ea1627536df Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 09:36:03 +0300 Subject: [PATCH 112/171] Torch compiler reset --- thunder/backend_optimizer/optimizer.py | 4 ++++ thunder/backend_optimizer/utils.py | 32 ++++++++++++++------------ 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index fad3fe333c..68a945603c 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -714,6 +714,10 @@ def measure_and_update_result(): n_missing_bsyms = len(group) - start_idx # TODO (matteochen): consider to add the iteration with no fusion regions for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): + if ex.name == 'torchcompile': + import torch + torch.compiler.reset() + # for i in range(0, n_missing_bsyms): # From top to bottom (this will include the whole region) # -> First iteration is the one with fusion region with single element diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 8add55a3c9..8209b74043 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -303,6 +303,8 @@ def benchmark_trace( from thunder.executors.passes import del_last_used import inspect + torch.compiler.reset() + def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: warm_up_iters = 50 @@ -310,20 +312,20 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f torch.cuda.synchronize() # Warm up cycles for _ in range(warm_up_iters): - cloned_args = clone_args(args) - fn(*cloned_args) - del cloned_args + # cloned_args = clone_args(args) + fn(*args) + # del cloned_args # Benchmark torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() for i in range(iters): - cloned_args = clone_args(args) + # cloned_args = clone_args(args) torch.cuda.empty_cache() torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nvsight_fn_name}, iter{i}") - fn(*cloned_args) + fn(*args) torch.cuda.nvtx.range_pop() - del cloned_args + # del cloned_args torch.cuda.cudart().cudaProfilerStop() return float("inf"), float("inf"), None @@ -356,19 +358,19 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl # Warm up cycles for _ in range(warm_up_iters): - cloned_args = clone_args(args) - out = fn(*cloned_args) - del cloned_args + # cloned_args = clone_args(args) + out = fn(*args) + # del cloned_args # Snapshot request if snapshot: - cloned_args = clone_args(args) + # cloned_args = clone_args(args) torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.memory._record_memory_history() - fn(*cloned_args) + fn(*args) torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") torch.cuda.memory._record_memory_history(enabled=None) - del cloned_args + # del cloned_args # Benchmark stream = torch.cuda.current_stream() max_allocated_bytes = 0 @@ -377,17 +379,17 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl torch.cuda.synchronize() for i in range(iters): current_iter = i - cloned_args = clone_args(args) + # cloned_args = clone_args(args) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) start_events[i].record(stream) - fn(*cloned_args) + fn(*args) end_events[i].record(stream) max_allocated_bytes = max( max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) ) - del cloned_args + # del cloned_args torch.cuda.synchronize() times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] From fbf94e05018d310b8145c91f4dec6a0a2481bec3 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 09:36:52 +0300 Subject: [PATCH 113/171] Updated litgpt runner --- examples/dev/litGPT.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 662b2ffdf9..d4faf1ef79 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -8,7 +8,6 @@ from thunder.tests.litgpt_model import Config import thunder import torch -from thunder.executors.nvmathex import nvmath_ex torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn @@ -40,14 +39,13 @@ def __init__( executors=[ "cudnn", "sdpa", - "fa3", "nvfuser", "torchcompile", ], ), Test( 1, - "runtime", + "memory", 1, executors=[ "cudnn", @@ -60,7 +58,15 @@ def __init__( 1, "runtime", 1, - executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile", nvmath_ex], + executors=["cudnn", "sdpa", "nvfuser", "torchcompile"], + model_name="stablecode-completion-alpha-3b", + ), + Test( + 1, + "memory", + 1, + executors=["cudnn", "sdpa", "nvfuser", "torchcompile"], + model_name="stablecode-completion-alpha-3b", ), ] @@ -77,6 +83,8 @@ def __init__( x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size)) print(f"Input size: {x.size()}") + eager = model + torch_compile = torch.compile(model) jmodel_def = thunder.jit(model) jmodel_auto = thunder.jit( model, @@ -88,7 +96,7 @@ def __init__( print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) - iters = 100 + iters = 40 fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], @@ -99,15 +107,15 @@ def __init__( ] fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] + print('\n\n####################################################', test.model_name) print(f"Results thunder benchmark ({iters} iters):") thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=False) # thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) - print(test.model_name) print(f"\n\nResults torch fw bw benchmark ({iters} iters):") - callables = [jmodel_def, jmodel_auto] - labels = ["def", "auto"] - inputs = [x, x] + callables = [eager, torch_compile, jmodel_def, jmodel_auto] + labels = ['eager', 'torch.compile', 'Thunder', 'Thunder Autotuner'] + inputs = [x, x, x, x] torch_fw_bw_benchmark(callables, labels, inputs, iters) print(f"\n\nResults torch total benchmark ({iters} iters):") torch_total_benchmark(callables, labels, inputs, iters) From be3912a1eed03b95d746b0064f73a41d9a93b812 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 09:56:20 +0300 Subject: [PATCH 114/171] Added guard for args cloning --- thunder/backend_optimizer/utils.py | 48 ++++++++++++++---------------- 1 file changed, 23 insertions(+), 25 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 8209b74043..0803a4fdbf 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -305,6 +305,24 @@ def benchmark_trace( torch.compiler.reset() + # TODO: If TE is used inside the trace we have to clone the input arguments as + # we are currently seeing benchmarking issues at the iteration i > 0 + def clone_args_if_needed(args): + te_used = is_te_used(trace) + if not te_used: + return args + res = [] + # Detatching the tensors as for standalone trace benchmarks we are not interested in the gradients + for arg in args: + if isinstance(arg, Sequence): + res.append(clone_args_if_needed(arg)) + else: + if isinstance(arg, torch.Tensor): + res.append(arg.clone().detach()) + else: + res.append(arg) + return tuple(res) + def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: warm_up_iters = 50 @@ -312,20 +330,18 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f torch.cuda.synchronize() # Warm up cycles for _ in range(warm_up_iters): - # cloned_args = clone_args(args) + args = clone_args_if_needed(args) fn(*args) - # del cloned_args # Benchmark torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() for i in range(iters): - # cloned_args = clone_args(args) + args = clone_args_if_needed(args) torch.cuda.empty_cache() torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nvsight_fn_name}, iter{i}") fn(*args) torch.cuda.nvtx.range_pop() - # del cloned_args torch.cuda.cudart().cudaProfilerStop() return float("inf"), float("inf"), None @@ -336,41 +352,25 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f print(f"#Trace execution failed for nvsight (error: {e}):\n{trc}") raise e - def clone_args(args): - res = [] - for arg in args: - if isinstance(arg, Sequence): - res.append(clone_args(arg)) - else: - if isinstance(arg, torch.Tensor): - res.append(arg.clone().detach()) - else: - res.append(arg) - return tuple(res) - def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: try: current_iter = 0 warm_up_iters = 50 out = None - # print_args(args) - # Warm up cycles for _ in range(warm_up_iters): - # cloned_args = clone_args(args) + args = clone_args_if_needed(args) out = fn(*args) - # del cloned_args # Snapshot request if snapshot: - # cloned_args = clone_args(args) + args = clone_args_if_needed(args) torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.memory._record_memory_history() fn(*args) torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") torch.cuda.memory._record_memory_history(enabled=None) - # del cloned_args # Benchmark stream = torch.cuda.current_stream() max_allocated_bytes = 0 @@ -379,7 +379,7 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl torch.cuda.synchronize() for i in range(iters): current_iter = i - # cloned_args = clone_args(args) + args = clone_args_if_needed(args) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) @@ -389,11 +389,9 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl max_allocated_bytes = max( max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) ) - # del cloned_args torch.cuda.synchronize() times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - # print(f"times: {times}") tot_time = sum(times) / iters return tot_time, max_allocated_bytes, out except Exception as e: From 7a1cf140284c3e1d09a752627de03b477bf2c988 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 13:50:24 +0300 Subject: [PATCH 115/171] Added guard for name attribute --- thunder/backend_optimizer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 0803a4fdbf..cd4f3c0fc9 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -94,7 +94,7 @@ def unpack_output(out) -> Sequence[Proxy]: ans.append(e) from thunder.backend_optimizer.optimizer import log, LogLevel - log(f"Returning not used proxies: {[p.name for p in ans]}", level=LogLevel.DEBUG) + log(f"Returning not used proxies: {[p.name if hasattr(p, 'name') else p for p in ans ]}", level=LogLevel.DEBUG) return ans From e746c58da8855187032ff7248490055bf520f335 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 13:50:59 +0300 Subject: [PATCH 116/171] Fixed var overwritten --- thunder/backend_optimizer/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index cd4f3c0fc9..a1c55d198c 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -330,17 +330,17 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f torch.cuda.synchronize() # Warm up cycles for _ in range(warm_up_iters): - args = clone_args_if_needed(args) - fn(*args) + new_args = clone_args_if_needed(args) + fn(*new_args) # Benchmark torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.cudart().cudaProfilerStart() for i in range(iters): - args = clone_args_if_needed(args) + new_args = clone_args_if_needed(args) torch.cuda.empty_cache() torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nvsight_fn_name}, iter{i}") - fn(*args) + fn(*new_args) torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() @@ -360,15 +360,15 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl # Warm up cycles for _ in range(warm_up_iters): - args = clone_args_if_needed(args) - out = fn(*args) + new_args = clone_args_if_needed(args) + out = fn(*new_args) # Snapshot request if snapshot: - args = clone_args_if_needed(args) + new_args = clone_args_if_needed(args) torch.cuda.empty_cache() torch.cuda.synchronize() torch.cuda.memory._record_memory_history() - fn(*args) + fn(*new_args) torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") torch.cuda.memory._record_memory_history(enabled=None) # Benchmark @@ -379,12 +379,12 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl torch.cuda.synchronize() for i in range(iters): current_iter = i - args = clone_args_if_needed(args) + new_args = clone_args_if_needed(args) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) start_events[i].record(stream) - fn(*args) + fn(*new_args) end_events[i].record(stream) max_allocated_bytes = max( max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device()) From de67643b173253ac9be8226c45a24e715322cdc8 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 14:08:17 +0300 Subject: [PATCH 117/171] Tests --- thunder/executors/torch_autograd.py | 2 +- thunder/tests/test_autotuner.py | 135 ++++++++++++++++++++++++++-- 2 files changed, 127 insertions(+), 10 deletions(-) diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py index c18a847b0a..756df1477b 100644 --- a/thunder/executors/torch_autograd.py +++ b/thunder/executors/torch_autograd.py @@ -212,7 +212,7 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat else BackendOptimizer( priority_executors=compile_data.executors_list, apply_bucketing_bw_trace=do_apply_bucketing_bw_trace, - produce_log=True, + produce_log=False, optimizer_type=autotune_type, compile_data=compile_data, ) diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 51513f725a..4a44c9a966 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -1,18 +1,24 @@ -from typing import Callable, Sequence -import thunder.backend_optimizer.utils as aut_utils -import pytest -import torch -import thunder +from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs from thunder.core.proxies import FloatProxy, IntegerProxy, TensorProxy -from thunder.core.symbol import BoundSymbol +from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import TraceCtx -from thunder.extend import Executor, get_always_executors -from thunder.executors.torchex import ex as torchex -from thunder.executors.torch_compile import torch_compile_ex +from thunder.executors.cudnnex import cudnn_ex +from thunder.executors.fa3ex import fa3_ex from thunder.executors.nvfuserex import nvfuserex +from thunder.executors.pythonex import ex as pythonex +from thunder.executors.sdpaex import sdpa_ex +from thunder.executors.torch_compile import torch_compile_ex +from thunder.executors.torchex import ex as torchex +from thunder.executors.transformer_engineex import transformer_engine_ex +from thunder.extend import Executor, get_always_executors from thunder.tests.framework import requiresCUDA +from typing import Callable, Sequence +import pytest +import thunder +import thunder.backend_optimizer.utils as aut_utils +import torch class DummyProxy: @@ -451,3 +457,114 @@ def forward(self, x: torch.Tensor): @requiresCUDA def test_reorder_executors_list(executors, expected): assert aut_utils.reorder_executors_list(executors) == expected + +@pytest.mark.parametrize( + "name, expected", + [ + ('linear', [transformer_engine_ex]), + ('scaled_dot_product_attention', [sdpa_ex, cudnn_ex, fa3_ex]) + ], +) +def test_get_fw_bw_split_backends_options(name: str, expected): + symbol = Symbol(name=name) + bsym = BoundSymbol(symbol, (), {}, None) + options = get_fw_bw_split_backends_options(bsym) + assert all(map(lambda v: v in options, expected)) + +class Model_1(torch.nn.Module): + def __init__(self, in_f, out_f) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_f, out_f) + + def forward(self, x): + t0 = self.linear(x) + return torch.nn.functional.silu(t0) + +class Model_2(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.n_head = 12 + self.n_embd = 3072 + self.c_attn = torch.nn.Linear(self.n_embd, 3 * self.n_embd, bias=False) + + def forward(self, x): + B, T, C = x.size() + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + return torch.nn.functional.scaled_dot_product_attention(q, k, v) + +@pytest.mark.parametrize( + "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_cudagraphs", + [ + (Model_1(32, 32), (32, 32), torch.float32, "runtime", [nvfuserex], [[nvfuserex, torchex, pythonex]], True), + ( + Model_1(32, 32), + (32, 32), + torch.float32, + "memory", + [torch_compile_ex], + [[torch_compile_ex, torchex, pythonex]], + True, + ), + ( + Model_1(4096, 4096), + (128, 4096), + torch.float32, + "runtime", + [transformer_engine_ex], + [[transformer_engine_ex, nvfuserex, torchex, pythonex]], + False, + ), + ( + Model_2(), + (16, 1024, 3072), + torch.float16, + "runtime", + [sdpa_ex, cudnn_ex], + [[sdpa_ex, nvfuserex, torchex, pythonex], [cudnn_ex, nvfuserex, torchex, pythonex]], + False, + ), + ( + Model_2(), + (16, 1024, 3072), + torch.float32, + "runtime", + [sdpa_ex, transformer_engine_ex], + [[sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex]], + False, + ), + ], +) +@requiresCUDA +def test_autotuner( + model: torch.nn.Module, + tensor_shape: tuple, + dtype: torch.dtype, + autotune_type: str, + executors: list, + expected_executors: list[list], + use_cudagraphs: bool, +): + def _run(): + model.to('cuda') + x = torch.randn(tensor_shape, dtype=dtype, device='cuda') + jitted_def = thunder.jit(model, executors=executors) + jitted_auto = thunder.jit(model, autotune_type=autotune_type, executors=executors, use_cudagraphs=use_cudagraphs) + y_def = jitted_def(x) + y_auto = jitted_auto(x) + + te_used = aut_utils.is_te_used(thunder.last_traces(jitted_auto)[-1]) + got = thunder.executors_applied(jitted_auto) + assert any([t == got for t in expected_executors]) + # With TE enabled deviation ((y_def - y_auto).abs().max().item()) is between tensors are ~0.2 + # For the else branch: https://pytorch.org/docs/stable/testing.html + torch.testing.assert_close(y_def, y_auto, atol=2 * 1e-1 if te_used else 1e-5, rtol=1e-1 if te_used else 1.3e-6) + + if dtype != torch.get_default_dtype(): + with torch.autocast(device_type="cuda"): + _run() + else: + _run() + From 56c8eaf5298e0771d663587936b7730c4b8e322a Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 14:11:40 +0300 Subject: [PATCH 118/171] Updated litgpt --- examples/dev/litGPT.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index d4faf1ef79..01703c092a 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -39,6 +39,7 @@ def __init__( executors=[ "cudnn", "sdpa", + "fa3", "nvfuser", "torchcompile", ], @@ -50,6 +51,7 @@ def __init__( executors=[ "cudnn", "sdpa", + "fa3", "nvfuser", "torchcompile", ], @@ -58,14 +60,14 @@ def __init__( 1, "runtime", 1, - executors=["cudnn", "sdpa", "nvfuser", "torchcompile"], + executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], model_name="stablecode-completion-alpha-3b", ), Test( 1, "memory", 1, - executors=["cudnn", "sdpa", "nvfuser", "torchcompile"], + executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], model_name="stablecode-completion-alpha-3b", ), ] From ff0612545ffb6f23a6a6cc4e5a7b0480f605561b Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 14:24:28 +0300 Subject: [PATCH 119/171] Removed not used --- thunder/backend_optimizer/optimizer.py | 78 +++----------------------- 1 file changed, 9 insertions(+), 69 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 68a945603c..1b5f87b2a6 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -162,7 +162,7 @@ def __init__( self, *, priority_executors: Sequence[Executor], - produce_log: bool = True, + produce_log: bool = False, apply_bucketing_bw_trace: bool, log_file_name: str, optimizer_type: OptimizerType = OptimizerType.RUNTIME, @@ -215,7 +215,7 @@ def __init__( self, *, priority_executors: Sequence[Executor], - produce_log: bool = True, + produce_log: bool = False, apply_bucketing_bw_trace: bool, log_file_name: str, optimizer_type: OptimizerType = OptimizerType.RUNTIME, @@ -288,22 +288,22 @@ def _best_runtime_and_memory_candidates(self, candidates): pair_cost_time = 0 pair_cost_mem = 0 t, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) - # log(f"Pair fw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.INFO) + log(f"Pair fw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.DEBUG) pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m t, m, _ = benchmark_trace(bw, iters=self.benchmark_iters, fw_trace=fw) - # log(f"Pair bw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.INFO) + log(f"Pair bw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.DEBUG) pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m if pair_cost_time < min_value_time: best_pair_runtime = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_time) - # log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.INFO) + log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.DEBUG) min_value_time = pair_cost_time if pair_cost_mem < min_value_mem: best_pair_memory = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_mem) - # log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.INFO) + log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.DEBUG) min_value_mem = pair_cost_mem return best_pair_runtime, best_pair_memory @@ -325,27 +325,14 @@ def fw_benchmark(): label = list(pair_time.keys())[0] # TODO (matteochen): remove the benchmark here as will done later on the bw pass c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) - # log( - # f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_time}', - # level=LogLevel.INFO, - # ) self.debug_msg += ( f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" ) c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) - # log( - # f'Benchmark fw end: Trace = [{label}] (time = {c} ms, mem = {m / (2**30)} GB)":\n{trc_mem}', - # level=LogLevel.INFO, - # ) self.debug_msg += ( f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) # For forward trace we cache the best placement for both runtime and memory for the current Fusion executor (represented by label) - # if compile_opt_time is not None: - # print(f"Caching fw with compile options time: {compile_opt_time.fusion_tag}") - # if compile_opt_mem is not None: - # print(f"Caching fw with compile options mem: {compile_opt_mem.fusion_tag}") - for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): log(f'Caching fw candidate [compile option: {o.fusion_tag if o else "None"}]') self.cached_fw_traces.append( @@ -367,10 +354,6 @@ def bw_benchmark(): trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] ) self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" - # log( - # f'Benchmark trace (target TIME) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', - # level=LogLevel.INFO, - # ) if trace_time < time_result.runtime: time_result = BenchmarkResult(time=trace_time, memory=trace_mem, trace=trace, label=label, index=i) @@ -383,26 +366,12 @@ def bw_benchmark(): trace_time, trace_mem, res = benchmark_trace( trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] ) - del res self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" - # log( - # f'Benchmark trace (target MEM) "{label}" (time = {trace_time} ms, mem = {trace_mem / (2**30)} GB:\n{trace}', - # level=LogLevel.INFO, - # ) if trace_mem < memory_result.memory: memory_result = BenchmarkResult( time=trace_time, memory=trace_mem, trace=trace, label=label, index=i ) - # log( - # f'Benchmark end: Best trace time "{time_result.label} (time = {time_result.runtime} ms)":\n{time_result.trace}', - # level=LogLevel.INFO, - # ) - # log( - # f'Benchmark end: Best trace mem "{memory_result.label} (mem = {memory_result.memory / (2 ** 30)} GB)":\n{memory_result.trace}', - # level=LogLevel.INFO, - # ) - # Here we have to recover the traces without the pass through remat in order to be compliant # with thunder flow as we might have request for no remat # Unpack dict @@ -436,7 +405,6 @@ def bw_benchmark(): def _search_candidates(self, increment_factor: int = 1): from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols - from thunder.core.rematerialization import rematerialize_forward_and_backward from thunder.backend_optimizer.utils import ( get_not_used_intermediate_outsputs, sequence_hash, @@ -449,7 +417,7 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo trc = from_trace(self.trace) trc.bound_symbols = list(bound_symbols_in) - # For this partial trace we have to return all not used tensors otherwise the dce will cut them out + # For this partial trace we have to return all not used tensors otherwise the dce remove them tensors = get_not_used_intermediate_outsputs(trc) forced_return_bsym = self.trace.bound_symbols[-1].from_bsym(args=tensors) @@ -460,8 +428,6 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo for bsym in trc.bound_symbols: if bsym.sym.name == "return": raise AssertionError("Return statement should not be here") - # executor_configuration.append(empty_executor) - # keys.append('return') elif isinstance(bsym.output, Sequence): seq_hash = sequence_hash(bsym.output) executor_configuration.append(mapping.get(seq_hash, empty_executor)) @@ -678,8 +644,6 @@ def measure_and_update_result(): nonlocal best_placement_mem nonlocal best_keys_mem trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) - # if self.trace_type == TraceType.BW and self.active_fw_trace_ctx[0] is not None: - # _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) cost, mem, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) log(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) if cost < best_res_time.runtime or (cost == best_res_time.runtime and mem < best_res_time.memory): @@ -738,26 +702,6 @@ def measure_and_update_result(): # Benchmark measure_and_update_result() - # TODO (matteochen): consider if this can increase placement - # From bottom to up (this will exclude the full region as being handled in the for cycle above) - # -> First iteration is the one with len(fusion_region) - 1 - # -> Last iteration gives no fusion regions - # for j in range(start_idx, start_idx + i + 1, increment_factor): - # match_bsym_output( - # group[j], - # [dict_time_strat, dict_mem_strat], - # get_first_available_operator_executor( - # bsym=group[j], - # executors=self.executors, - # empty_hash=self.empty_executor_hashable_placeholder, - # ), - # ) - # for k in range(start_idx + i + 1, len(group), increment_factor): - # match_bsym_output(group[k], [dict_time_strat, dict_mem_strat], ex) - - # # Benchmark this placement - # measure_and_update_result() - if best_placement_time is None or best_keys_time is None: raise AssertionError("Failed to get best time placement") if best_placement_mem is None or best_keys_mem is None: @@ -812,7 +756,7 @@ def measure_and_update_result(): raise AssertionError(f"Type not handled: {type(bsym.output)}") # For the forward trace we benchmark (memory) the mocked return statement as we don't know which - # tensor will be returned after the rematerialize_forward_and_backward() call in order to do not underestimate the memory consumption + # tensor will be returned after the rematerialize_forward_and_backward call in order to do not underestimate the memory consumption trace = self.trace if self.trace_type == TraceType.FW: trace = from_trace(self.trace) @@ -835,10 +779,6 @@ def measure_and_update_result(): always_executors=self.always_executors, empty_str=self.empty_executor_hashable_placeholder, ) - # print(f"Assigned trace:\n{trc}") - # if self.trace_type == TraceType.BW: - # # pass - # _, trc = rematerialize_forward_and_backward(self.active_fw_trace_ctx[0], trc) container.append({ex.name: trc}) # Save executors in order to generate real fw and bw trace with correct output with the placer @@ -1030,7 +970,7 @@ def __init__( self, *, priority_executors: Sequence[Executor], - produce_log=True, + produce_log=False, apply_bucketing_bw_trace: bool, log_file_name="autotune_debug.log", optimizer_type: OptimizerType = OptimizerType.RUNTIME, From 637d5cebbbc21e8b6adf406eb375aaad2ce743cc Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 28 Aug 2024 19:15:26 +0300 Subject: [PATCH 120/171] Wip on common transformer block replacement --- thunder/backend_optimizer/utils.py | 175 ++++++++++++++++++++++++++++- thunder/tests/test_autotuner.py | 37 ++++++ 2 files changed, 211 insertions(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index a1c55d198c..925fd3320d 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1,9 +1,10 @@ from collections.abc import Callable, Hashable, Sequence from typing import Any +from thunder.core import symbol from thunder.core.compile_data import get_compile_data from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs -from thunder.core.proxies import AnyProxy, FloatProxy, IntegerProxy, Proxy, TensorProxy, Variable, variableify +from thunder.core.proxies import AnyProxy, FloatProxy, IntegerProxy, NumberProxy, Proxy, TensorProxy, Variable, variableify from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import TraceCtx, get_tracectx, reset_tracectx, set_tracectx from thunder.extend import Executor, FusionExecutor, OperatorExecutor @@ -766,3 +767,175 @@ def reorder_executors_list(executors: Sequence): reordered.insert(0, nvfuser_ex.name if are_inputs_names else nvfuser_ex) return reordered + +def symbol_hash(bsym: BoundSymbol): + def _tensor_hash(t: TensorProxy) -> str: + assert t.dtype + shapes = [str(s) for s in t.shape] + return '{' + '-'.join(shapes) + '/' + str(t.device) + t.dtype.full_name + str(t.requires_grad) + '}' + + def _number_hash(t: NumberProxy) -> str: + return '{' + str(t.value) + '}' + + def _sequence_hash(s: Sequence | None) -> str: + if s is None: + return "None" + + ret = '[' + for e in s: + if e is None: + ret += '{None}' + elif isinstance(e, TensorProxy): + ret += _tensor_hash(e) + ',' + elif isinstance(e, NumberProxy): + ret += _number_hash(e) + ',' + elif isinstance(e, Sequence): + ret += _sequence_hash(e) + ',' + elif isinstance(e, int) or isinstance(e, float) or isinstance(e, bool): + ret += f'{(type(e))}' + else: + raise RuntimeError(f'Not implemented {type(e)}') + return ret + ']' + + def _hash(bsym: BoundSymbol) -> str: + h = bsym.sym.name + # Handle tensor as output or sequences + if not isinstance(bsym.output, TensorProxy) and not isinstance(bsym.output, Sequence): + raise RuntimeError(f'type {type(bsym.output)} not implemented') + h += ( + "#out:" + + (_tensor_hash(bsym.output) if isinstance(bsym.output, TensorProxy) else _sequence_hash(bsym.output)) + + "#in:" + # Args is always a tuple + + _sequence_hash(bsym.args) + ) + return h + + return _hash(bsym) + +def repetead_transformer_blocks( + *, trace: TraceCtx, min_block_size=1, known_points: tuple[BoundSymbol, BoundSymbol] | None = None +) -> list[tuple]: + symbols = [ + s + for s in trace.bound_symbols + if not s.sym.name.startswith("python_del") and not s.sym.name.startswith("unpack") + ] + + def _tuple_name(tup: Sequence): + ret = '(' + for e in tup: + if e is None: + ret += 'None, ' + elif hasattr(e, 'name'): + ret += e.name + ', ' + elif isinstance(e, Sequence): + ret += _tuple_name(e) + ', ' + elif isinstance(e, int) or isinstance(e, float) or isinstance(e, bool): + ret += str(e) + ', ' + else: + raise RuntimeError(f'Not implemented {type(e)}') + return ret + ')' + + # Only bsym that have inputs and outputs + original_map_indexes = { + str(bsym.output.name) if isinstance(bsym.output, TensorProxy) else _tuple_name(bsym.output): i + for i, bsym in enumerate(trace.bound_symbols) + if not (bsym.output is None or not bsym.args) and bsym.sym.id != PrimIDs.RETURN + } + + def _lcs(start_indexes) -> int: + max_last_len = len(symbols)-1 + max_first_len = start_indexes[1] + print(max_first_len, max_last_len) + + lcs = 0 + while (start_indexes[0] < max_first_len and start_indexes[-1] < max_last_len): + # Get all the hashes + hashes = [symbol_hash(symbols[i]) for i in start_indexes] + # Advance if all the hashes coincides + uniques = set(hashes) + # print(f'unique: {len(uniques)}') + if len(uniques) == 1: + start_indexes = [i+1 for i in start_indexes] + lcs += 1 + else: + return lcs + return max(lcs, 1) + + def _skip(bsym: BoundSymbol) -> bool: + return bsym.output is None or not bsym.args + + bsym_indexes: dict[str, list[int]] = {} + for i, bsym in enumerate(symbols): + if i == len(symbols)-1: + break + # Skip None outputs (unpacks, returns, del) + if _skip(bsym): + continue + # print(f'hashing {bsym.sym.name} at index {i}') + h = symbol_hash(bsym) + if h in bsym_indexes: + bsym_indexes[h].append(i) + else: + bsym_indexes[h] = [i] + + def _range_seen(index: int, s: set): + for r in s: + if index >= r[0] and index <= r[1]: + return True + return False + + seen_hashes = set() + seen_ranges = set() + max_lcs = 0 + res = [] + for i, bsym in enumerate(symbols): + if i == len(symbols)-1: + break + # Skip None outputs (unpacks, returns, del) + if _skip(bsym): + continue + + h = symbol_hash(bsym) + # Could not generate hash for this bsym + # Normally, bsym are expected to output a TensorProxy + if not isinstance(bsym.output, Proxy) or h in seen_hashes or _range_seen(i, seen_ranges): + continue + # else: + # print(f'checking lcs for {bsym.sym.name}') + + indexes = bsym_indexes.get(h, []) + print([f'index {i}, out_name: {symbols[i].output.name}' for i in indexes]) + seen_hashes.add(h) + if len(indexes) < 2: + continue + + # Now we can find the longest common sequence between all the occurences + lcs = _lcs(indexes) + print('\n####################') + for index in indexes: + print(f'For index {index} lcs: {lcs}') + print(f'Starting bsym: {symbols[index]}') + print(f'Ending bsym: {symbols[index + lcs - 1]}') + print('\n####################') + if lcs > 1: + # Push every seen ranges to ignore all the subranges + for i in indexes: + seen_ranges.add((i, i+lcs-1)) + + # Set result + if lcs > max_lcs: + max_lcs = lcs + res = [(i, i+lcs-1) for i in indexes] + + print(f'\n\nMax lcs {max_lcs}') + print(res) + for r in res: + print(symbols[r[0]].output.name, symbols[r[1]].output.name) + return [ + (original_map_indexes[symbols[t[0]].output.name], original_map_indexes[symbols[t[1]].output.name]) for t in res + ] + + + diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 4a44c9a966..0a661a1870 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -568,3 +568,40 @@ def _run(): else: _run() + +def test_repetead_transformer_blocks(): + device = 'cuda' + # def _fn(x: torch.Tensor, y: torch.Tensor): + # a = x + x + # b = y * y + # c = x @ y + # aa = a + a + # bb = b * b + # return c, aa, bb + + # a = torch.randn(2,2, device = device) + # b = torch.randn(2,2, device = device) + + # jitted = thunder.jit(_fn) + # jitted(a, b) + + # trace = thunder.last_traces(jitted)[-1] + # aut_utils.repetead_transformer_blocks(trace=trace) + + from thunder.tests.litgpt_model import Config + from litgpt import GPT + cfg = Config.from_name('Llama-3-8B') + cfg.n_layer = 2 + + model = GPT(cfg) + model.to(device) + x = torch.randint(1, model.config.vocab_size, (1, cfg.block_size), device=device, ) + jitted = thunder.jit(model, executors=['sdpa']) + jitted(x) + + # trace = thunder.last_traces(jitted)[-1] + # ret = aut_utils.repetead_transformer_blocks(trace=trace) + # print(ret) + + + From fec126a41358569a0b36b45040eb81b83b9af762 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 30 Aug 2024 19:28:08 +0300 Subject: [PATCH 121/171] Transformer block optimization --- examples/dev/litGPT.py | 74 +++-- thunder/backend_optimizer/optimizer.py | 81 +++++- thunder/backend_optimizer/utils.py | 362 +++++++++++++++++++++---- thunder/tests/test_autotuner.py | 115 +++++--- 4 files changed, 512 insertions(+), 120 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 01703c092a..fea025192c 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -8,6 +8,7 @@ from thunder.tests.litgpt_model import Config import thunder import torch +import time torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn @@ -22,6 +23,8 @@ def __init__( seq_len: int = -1, model_name: str = "Llama-3-8B", executors=None, + optimize_transformer_blocks=True, + optimize_transformer_min_block_size=60, # for llama3 ) -> None: self.layers = layers self.autotune_type = autotune_type @@ -29,6 +32,8 @@ def __init__( self.seq_len = seq_len self.model_name = model_name self.executors = executors + self.optimize_transformer_blocks = (optimize_transformer_blocks,) + self.optimize_transformer_min_block_size = optimize_transformer_min_block_size layers = [ @@ -45,8 +50,8 @@ def __init__( ], ), Test( - 1, - "memory", + 4, + "runtime", 1, executors=[ "cudnn", @@ -56,20 +61,45 @@ def __init__( "torchcompile", ], ), - Test( - 1, - "runtime", - 1, - executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], - model_name="stablecode-completion-alpha-3b", - ), - Test( - 1, - "memory", - 1, - executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], - model_name="stablecode-completion-alpha-3b", - ), + # Test( + # 3, + # "runtime", + # 1, + # executors=[ + # "cudnn", + # "sdpa", + # # "fa3", + # "nvfuser", + # "torchcompile", + # ], + # seq_len=128 + # ), + # Test( + # 1, + # "memory", + # 1, + # executors=[ + # "cudnn", + # "sdpa", + # "fa3", + # "nvfuser", + # "torchcompile", + # ], + # ), + # Test( + # 1, + # "runtime", + # 1, + # executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], + # model_name="stablecode-completion-alpha-3b", + # ), + # Test( + # 1, + # "memory", + # 1, + # executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], + # model_name="stablecode-completion-alpha-3b", + # ), ] for test in layers: @@ -78,7 +108,7 @@ def __init__( cfg.n_layer = test.layers if test.seq_len != -1: cfg.block_size = test.seq_len - torch.set_default_dtype(torch.bfloat16) + torch.set_default_dtype(torch.float32) print(cfg) with torch.device("cuda"): model = GPT(cfg) @@ -93,10 +123,14 @@ def __init__( autotune_type=test.autotune_type, executors=test.executors, use_cudagraphs=False, + optimize_common_blocks=test.optimize_transformer_blocks, + optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, ) - print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + s = time.time_ns() print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) + e = time.time_ns() + print("Compilation time:", {(e - s) / 1000000000}, "s") iters = 40 fw_traces = [ @@ -109,14 +143,14 @@ def __init__( ] fw_labels = ["fw_def", "fw_auto"] bw_labels = ["bw_def", "bw_auto"] - print('\n\n####################################################', test.model_name) + print("\n\n####################################################", test.model_name) print(f"Results thunder benchmark ({iters} iters):") thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=False) # thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) print(f"\n\nResults torch fw bw benchmark ({iters} iters):") callables = [eager, torch_compile, jmodel_def, jmodel_auto] - labels = ['eager', 'torch.compile', 'Thunder', 'Thunder Autotuner'] + labels = ["eager", "torch.compile", "Thunder", "Thunder Autotuner"] inputs = [x, x, x, x] torch_fw_bw_benchmark(callables, labels, inputs, iters) print(f"\n\nResults torch total benchmark ({iters} iters):") diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 1b5f87b2a6..c8cf4a3a9f 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,6 +1,11 @@ from collections.abc import Callable, Sequence from enum import Enum -from thunder.backend_optimizer.utils import operation_in_trace, wrap_fn_with_exeuctor_compile_option +from thunder.backend_optimizer.utils import ( + map_executors_from_reduced_trace_to_complete_trace, + operation_in_trace, + wrap_fn_with_exeuctor_compile_option, +) +from thunder.core.compile_data import get_compile_data from thunder.core.prims import PrimIDs from thunder.core.proxies import CollectionProxy, FloatProxy, IntegerProxy, TensorProxy from thunder.core.symbol import BoundSymbol @@ -239,12 +244,15 @@ def __init__( self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { "nvfuser": [ - FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), + # FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), # FusionCompileOptionsHelper("nv_enable_matmul", "matmul", PrimIDs.MATMUL, matmul, _matmul_check), # FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), ] } + self.is_reduced: bool = False + self.cached_original_trace: TraceCtx | None = None + """ ################################################## Internal methods ################################################## """ @@ -350,7 +358,7 @@ def bw_benchmark(): # Unpack the dict label = list(pair.keys())[0] trace = list(pair.values())[0] - trace_time, trace_mem, res = benchmark_trace( + trace_time, trace_mem, _ = benchmark_trace( trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] ) self.debug_msg += f"Trace name = [{label}] - Target: TIME - Time = {trace_time} ms - Mem = {trace_mem / (2**30)} GB\n{trace}\n\n" @@ -363,7 +371,7 @@ def bw_benchmark(): label = list(pair.keys())[0] trace = list(pair.values())[0] - trace_time, trace_mem, res = benchmark_trace( + trace_time, trace_mem, _ = benchmark_trace( trace, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] ) self.debug_msg += f"Trace name = [{label}] - Target: MEM - Mem = {trace_mem / (2**30)} GB - Time = {trace_time} ms\n{trace}\n\n" @@ -519,12 +527,12 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): log(f"Number of Fusion groups = {len(bound_symbol_groups)}", level=LogLevel.DEBUG) # Print fusion groups if requested - for id, group in enumerate(bound_symbol_groups): - log(f"Group id: {id}", level=LogLevel.DEBUG) - for sub in group: - log(f"{sub.sym.name} -> out: {sub.output}", level=LogLevel.DEBUG) - if log_level == LogLevel.DEBUG and len(group) > 0: - print("\n") + # for id, group in enumerate(bound_symbol_groups): + # log(f"Group id: {id}", level=LogLevel.DEBUG) + # for sub in group: + # log(f"{sub.sym.name} -> out: {sub.output}", level=LogLevel.DEBUG) + # if log_level == LogLevel.DEBUG and len(group) > 0: + # print("\n") dict_time_strat: dict[str, Executor] = {} dict_mem_strat: dict[str, Executor] = {} @@ -678,8 +686,9 @@ def measure_and_update_result(): n_missing_bsyms = len(group) - start_idx # TODO (matteochen): consider to add the iteration with no fusion regions for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): - if ex.name == 'torchcompile': + if ex.name == "torchcompile": import torch + torch.compiler.reset() # for i in range(0, n_missing_bsyms): @@ -863,6 +872,7 @@ def optimize(self): from thunder.core.transform_common import dce from thunder.executors.torch_autograd import update_bw_from_forward_optimization from thunder.backend_optimizer.utils import assign_executors + from thunder.backend_optimizer.utils import repetead_trace_blocks, reduce_common_trace_blocks def _optimize(): # Reset fusion helpers @@ -870,8 +880,38 @@ def _optimize(): # Reset helpers data structures self.executor_placement_options = ExecutorPlacementOptions() + cd = get_compile_data() + # Check if common blocks optimization is requested + optimize_common_blocks = False if cd is None else cd.compile_options.get("optimize_common_blocks", False) + optimize_common_blocks_min_size = ( + -1 if cd is None else cd.compile_options.get("optimize_common_blocks_min_size", -1) + ) + + # Cut the compilation time if possible + common_trace_blocks = repetead_trace_blocks( + trace=self.trace, min_block_size=optimize_common_blocks_min_size if optimize_common_blocks else -1 + ) + # print(common_trace_blocks) + if len(common_trace_blocks) >= 2 and optimize_common_blocks: + log(f"Common blocks found {common_trace_blocks}", level=LogLevel.DEBUG) + reduced_trace = reduce_common_trace_blocks(trace=self.trace, common_blocks_in=common_trace_blocks) + log( + f"Operating on reduced trace (by cutting common transformer blocks):\n{reduced_trace}", + level=LogLevel.INFO, + ) + self.is_reduced = True + self.cached_original_trace = self.trace + self.trace = reduced_trace + else: + log( + "Optimizing the whole trace directly. No common transformer block optimization will be applied.", + level=LogLevel.INFO, + ) + + # This performs executor search self._search_candidates() + # From now on we have the optimized executors for each trace region. Apply them... if len(self.executor_placement_options.placement_options_time) != len( self.fusion_executors_saved_for_later ): @@ -882,6 +922,25 @@ def _optimize(): raise AssertionError( f"Unexpected mem placement options size: {len(self.executor_placement_options.placement_options_mem)}. Expected: {len(self.fusion_executors_saved_for_later)}" ) + + # If we optimized the reduced trace we now can share the placing with other blocks + if self.is_reduced and self.cached_original_trace is not None: + for placement_ctx in self.executor_placement_options.placement_options_time: + placement = map_executors_from_reduced_trace_to_complete_trace( + self.cached_original_trace, common_trace_blocks, placement_ctx.placement + ) + placement_ctx.placement = placement + + for placement_ctx in self.executor_placement_options.placement_options_mem: + placement = map_executors_from_reduced_trace_to_complete_trace( + self.cached_original_trace, common_trace_blocks, placement_ctx.placement + ) + placement_ctx.placement = placement + + # Reset original trace + self.trace = self.cached_original_trace + + # We will create the best compute time and peak memory consumption placement for each fusion executor for placement_ctx, ex in zip( self.executor_placement_options.placement_options_time, self.fusion_executors_saved_for_later ): diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 925fd3320d..e87493823a 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1,17 +1,27 @@ from collections.abc import Callable, Hashable, Sequence from typing import Any -from thunder.core import symbol + from thunder.core.compile_data import get_compile_data from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs -from thunder.core.proxies import AnyProxy, FloatProxy, IntegerProxy, NumberProxy, Proxy, TensorProxy, Variable, variableify +from thunder.core.proxies import ( + AnyProxy, + FloatProxy, + IntegerProxy, + NumberProxy, + Proxy, + TensorProxy, + Variable, + variableify, +) from thunder.core.symbol import BoundSymbol, Symbol -from thunder.core.trace import TraceCtx, get_tracectx, reset_tracectx, set_tracectx +from thunder.core.trace import TraceCtx, from_trace, get_tracectx, reset_tracectx, set_tracectx from thunder.extend import Executor, FusionExecutor, OperatorExecutor from thunder.core.utils import check, safe_map_flat import thunder.core.transforms as transforms from itertools import chain import torch +from thunder.core.dtypes import dtype # Maybe we can use id(s) @@ -768,40 +778,49 @@ def reorder_executors_list(executors: Sequence): return reordered + def symbol_hash(bsym: BoundSymbol): + # Maintainig essential metadata def _tensor_hash(t: TensorProxy) -> str: assert t.dtype shapes = [str(s) for s in t.shape] - return '{' + '-'.join(shapes) + '/' + str(t.device) + t.dtype.full_name + str(t.requires_grad) + '}' + return "{" + "-".join(shapes) + "/" + str(t.device) + t.dtype.full_name + str(t.requires_grad) + "}" def _number_hash(t: NumberProxy) -> str: - return '{' + str(t.value) + '}' + return "{" + str(t.value) + "}" + + def _any_proxy_hash(p: AnyProxy) -> str: + return "{" + p.__repr__() + "}" def _sequence_hash(s: Sequence | None) -> str: if s is None: return "None" - ret = '[' + ret = "[" for e in s: if e is None: - ret += '{None}' + ret += "{None}" elif isinstance(e, TensorProxy): - ret += _tensor_hash(e) + ',' + ret += _tensor_hash(e) + "," elif isinstance(e, NumberProxy): - ret += _number_hash(e) + ',' + ret += _number_hash(e) + "," elif isinstance(e, Sequence): - ret += _sequence_hash(e) + ',' + ret += _sequence_hash(e) + "," + elif isinstance(e, AnyProxy): + ret += _any_proxy_hash(e) elif isinstance(e, int) or isinstance(e, float) or isinstance(e, bool): - ret += f'{(type(e))}' + ret += f"{(type(e))}" + elif isinstance(e, dtype): + ret += f"{(type(e))}" else: - raise RuntimeError(f'Not implemented {type(e)}') - return ret + ']' + raise RuntimeError(f"Not implemented {type(e)}. Failed bsym: {bsym}") + return ret + "]" def _hash(bsym: BoundSymbol) -> str: h = bsym.sym.name # Handle tensor as output or sequences if not isinstance(bsym.output, TensorProxy) and not isinstance(bsym.output, Sequence): - raise RuntimeError(f'type {type(bsym.output)} not implemented') + raise RuntimeError(f"type {type(bsym.output)} not implemented") h += ( "#out:" + (_tensor_hash(bsym.output) if isinstance(bsym.output, TensorProxy) else _sequence_hash(bsym.output)) @@ -813,9 +832,15 @@ def _hash(bsym: BoundSymbol) -> str: return _hash(bsym) -def repetead_transformer_blocks( + +# Both lhs and rhs are included in the range +# TODO: known_points can be used to detect start and end of a block sequence +def repetead_trace_blocks( *, trace: TraceCtx, min_block_size=1, known_points: tuple[BoundSymbol, BoundSymbol] | None = None ) -> list[tuple]: + if min_block_size < 2: + return [] + symbols = [ s for s in trace.bound_symbols @@ -823,19 +848,19 @@ def repetead_transformer_blocks( ] def _tuple_name(tup: Sequence): - ret = '(' + ret = "(" for e in tup: if e is None: - ret += 'None, ' - elif hasattr(e, 'name'): - ret += e.name + ', ' + ret += "None, " + elif hasattr(e, "name"): + ret += e.name + ", " elif isinstance(e, Sequence): - ret += _tuple_name(e) + ', ' + ret += _tuple_name(e) + ", " elif isinstance(e, int) or isinstance(e, float) or isinstance(e, bool): - ret += str(e) + ', ' + ret += str(e) + ", " else: - raise RuntimeError(f'Not implemented {type(e)}') - return ret + ')' + raise RuntimeError(f"Not implemented {type(e)}") + return ret + ")" # Only bsym that have inputs and outputs original_map_indexes = { @@ -845,19 +870,17 @@ def _tuple_name(tup: Sequence): } def _lcs(start_indexes) -> int: - max_last_len = len(symbols)-1 + max_last_len = len(symbols) - 1 max_first_len = start_indexes[1] - print(max_first_len, max_last_len) lcs = 0 - while (start_indexes[0] < max_first_len and start_indexes[-1] < max_last_len): + while start_indexes[0] < max_first_len and start_indexes[-1] < max_last_len: # Get all the hashes hashes = [symbol_hash(symbols[i]) for i in start_indexes] # Advance if all the hashes coincides uniques = set(hashes) - # print(f'unique: {len(uniques)}') if len(uniques) == 1: - start_indexes = [i+1 for i in start_indexes] + start_indexes = [i + 1 for i in start_indexes] lcs += 1 else: return lcs @@ -868,12 +891,11 @@ def _skip(bsym: BoundSymbol) -> bool: bsym_indexes: dict[str, list[int]] = {} for i, bsym in enumerate(symbols): - if i == len(symbols)-1: + if i == len(symbols) - 1: break # Skip None outputs (unpacks, returns, del) if _skip(bsym): continue - # print(f'hashing {bsym.sym.name} at index {i}') h = symbol_hash(bsym) if h in bsym_indexes: bsym_indexes[h].append(i) @@ -891,7 +913,7 @@ def _range_seen(index: int, s: set): max_lcs = 0 res = [] for i, bsym in enumerate(symbols): - if i == len(symbols)-1: + if i == len(symbols) - 1: break # Skip None outputs (unpacks, returns, del) if _skip(bsym): @@ -902,40 +924,288 @@ def _range_seen(index: int, s: set): # Normally, bsym are expected to output a TensorProxy if not isinstance(bsym.output, Proxy) or h in seen_hashes or _range_seen(i, seen_ranges): continue - # else: - # print(f'checking lcs for {bsym.sym.name}') indexes = bsym_indexes.get(h, []) - print([f'index {i}, out_name: {symbols[i].output.name}' for i in indexes]) seen_hashes.add(h) if len(indexes) < 2: continue # Now we can find the longest common sequence between all the occurences lcs = _lcs(indexes) - print('\n####################') - for index in indexes: - print(f'For index {index} lcs: {lcs}') - print(f'Starting bsym: {symbols[index]}') - print(f'Ending bsym: {symbols[index + lcs - 1]}') - print('\n####################') + # print('\n####################') + # for index in indexes: + # print(f'For index {index} lcs: {lcs}') + # print(f'Starting bsym: {symbols[index]}') + # print(f'Ending bsym: {symbols[index + lcs - 1]}') + # print('\n####################') if lcs > 1: # Push every seen ranges to ignore all the subranges for i in indexes: - seen_ranges.add((i, i+lcs-1)) + seen_ranges.add((i, i + lcs - 1)) # Set result if lcs > max_lcs: max_lcs = lcs - res = [(i, i+lcs-1) for i in indexes] + res = [(i, i + lcs - 1) for i in indexes] - print(f'\n\nMax lcs {max_lcs}') - print(res) - for r in res: - print(symbols[r[0]].output.name, symbols[r[1]].output.name) + if max_lcs < min_block_size: + return [] + + # print(f'\n\nMax lcs {max_lcs}') + # print(res) + + # for r in res: + # print(symbols[r[0]].output.name, symbols[r[1]].output.name) return [ (original_map_indexes[symbols[t[0]].output.name], original_map_indexes[symbols[t[1]].output.name]) for t in res ] +# What is regions_between_blocks? +# They are trace regions between one transformer block and the next one in the backward pass and given that these regions are not present +# at the end of the last transformer block it means that they are needed in order to prepare shapes or strides +# for the block at i+1 from the output of block i. +# For example if common blocks looks like: [(32, 155), (157, 280)] +# the symbol at index 156 (the gap) looks like: +# In the forward trace we have not these gaps (so far). +def _regions_between_blocks(trace: TraceCtx, common_blocks: list[tuple]) -> int: + def _assert_args(seq_a: Sequence, seq_b: Sequence): + assert len(seq_a) == len(seq_b) + for a, b in zip(seq_a, seq_b): + assert type(a) == type(b) + if isinstance(a, TensorProxy): + assert a.shape == b.shape + assert a.dtype == b.dtype + elif isinstance(a, Sequence): + _assert_args(a, b) + + regions_between_blocks = common_blocks[1][0] - common_blocks[0][1] - 1 + trace_region_between_common_blocks = trace.bound_symbols[common_blocks[0][1] + 1 : common_blocks[1][0]] + for i in range(1, len(common_blocks)): + if not common_blocks[i][0] - common_blocks[i - 1][1] - 1 == regions_between_blocks: + raise AssertionError( + "Trace configuration not supported. All the trace regions between common blocks are expected to have the same number of instructions." + ) + + # Check that the trace regions are equal + test_trace_regions = trace.bound_symbols[common_blocks[i - 1][1] + 1 : common_blocks[i][0]] + assert len(test_trace_regions) == len(trace_region_between_common_blocks) + for a, b in zip(test_trace_regions, trace_region_between_common_blocks): + assert a.sym.name == b.sym.name + _assert_args(a.args, b.args) + + return regions_between_blocks + + +def _indices_to_exclude_between_common_blocks(common_blocks: list[tuple]) -> list: + if len(common_blocks) < 2: + return [] + + ret = [] + for i in range(1, len(common_blocks)): + start_gap_index = common_blocks[i - 1][1] + 1 + end_gap_index = common_blocks[i][0] - 1 + ret.extend([j for j in range(start_gap_index, end_gap_index + 1)]) + return ret + + +def reduce_common_trace_blocks( + *, trace: TraceCtx, common_blocks_in: list[tuple], skip_between_blocks: bool = True +) -> TraceCtx: + def _exclude(blocks: list[tuple[int, int]], index: int, black_list: set): + # Exclude if the index is in a repeated block + for block in blocks: + if index >= block[0] and index <= block[1]: + return True + + # Exclude if it marked as to remove + if index in black_list and skip_between_blocks: + return True + return False + + def _find_bsym_index(out_name: str, space: Sequence[BoundSymbol]) -> int: + for i, b in enumerate(space): + if b.output is not None and hasattr(b.output, "name") and b.output.name == out_name: + return i + raise RuntimeError(f"Can not found bsym with output {out_name} in the search space.") + + common_blocks = list(common_blocks_in) + if len(common_blocks) < 2: + trc = from_trace(trace) + trc.bound_symbols = list(trace.bound_symbols) + return trc + + # Create a mapping where we can easily find to which block a specific output blongs + output_to_block: dict[str, tuple[int, int]] = {} + for n_block, block in enumerate(common_blocks): + for i in range(block[0], block[1] + 1): + bsym = trace.bound_symbols[i] + if not hasattr(bsym.output, "name"): + continue + output_to_block[bsym.output.name] = (n_block, i - block[0]) + + # Check that we maintain the pattern + regions_between_blocks = _regions_between_blocks(trace, common_blocks) + + # We have to exlude these gaps indices from the reduce trace + index_gaps_to_exclude = [] + if regions_between_blocks: + index_gaps_to_exclude = _indices_to_exclude_between_common_blocks(common_blocks) + # Make it fast to search in + index_gaps_to_exclude = set(index_gaps_to_exclude) + + # Create reduce trace regions + bound_symbols: list[BoundSymbol] = [ + b for i, b in enumerate(trace.bound_symbols) if not _exclude(common_blocks[1:], i, index_gaps_to_exclude) + ] + + # Retrive first and last blocks + first_block = common_blocks[0] + # common_blocks = common_blocks[1:] + + # Now, we have to update the trace region inputs after the last block to accepts the outputs of the first block, if it's not the return statement + if trace.bound_symbols[common_blocks[-1][1] + 1].sym.id != PrimIDs.RETURN: + # first_block_outputs = trace.bound_symbols[first_block[1]].output + # last_block_outputs = trace.bound_symbols[common_blocks[-1][1]].output + + # if not isinstance(first_block_outputs, Sequence): + # first_block_outputs = [first_block_outputs] + # if not isinstance(last_block_outputs, Sequence): + # last_block_outputs = [last_block_outputs] + + symbol_to_correct_index = _find_bsym_index( + trace.bound_symbols[common_blocks[-1][1] + 1].output.name, bound_symbols + ) + symbol_to_correct = bound_symbols[symbol_to_correct_index] + + def _correct_args(target: BoundSymbol): + args = [] + for arg in target.args: + if arg is None: + args.append(None) + elif hasattr(arg, "name") and arg.name in output_to_block: + _, index_in_block = output_to_block[arg.name] + # Recover the argument from the first block + args.append(trace.bound_symbols[common_blocks[0][0] + index_in_block].output) + elif isinstance(arg, Sequence): + raise RuntimeError("Not implemented") + else: + args.append(arg) + return args + + def _correct_bsym(bsym: BoundSymbol) -> BoundSymbol: + bsym = bsym.from_bsym(args=_correct_args(bsym)) + return bsym + + new_subsymbols = [] + for sub in symbol_to_correct.subsymbols: + new_sub = _correct_bsym(sub) + new_subsymbols.append(new_sub) + + bound_symbols[symbol_to_correct_index] = symbol_to_correct.from_bsym( + args=_correct_args(symbol_to_correct), subsymbols=new_subsymbols + ) + + # print(bound_symbols[symbol_to_correct_index]) + + # We need to check also the return statements as we have fewer args now + flatten_bsyms = flatten_sequence([b.output for b in bound_symbols]) + args_remained = set([b.name for b in flatten_bsyms if b is not None and hasattr(b, "name")]) + # Fw trace + if isinstance(bound_symbols[-1].args[0], dict): + saved_for_backward = tuple( + [e for e in bound_symbols[-1].args[1][0] if hasattr(e, "name") and e.name in args_remained] + ) + if isinstance(bound_symbols[-1].args[0]["output"], Sequence): + output = tuple( + [o for o in bound_symbols[-1].args[0]["output"] if hasattr(o, "name") and o.name in args_remained] + ) + else: + output = bound_symbols[-1].args[0]["output"] + flat_output = tuple( + [o for o in bound_symbols[-1].args[0]["flat_output"] if hasattr(o, "name") and o.name in args_remained] + ) + new_dict = {"output": output, "flat_output": flat_output, "flat_args": bound_symbols[-1].args[0]["flat_args"]} + + # Create the new args and substitute return symbol + bsym = bound_symbols[-1].from_bsym(args=(new_dict, (saved_for_backward, bound_symbols[-1].args[1][1]))) + bound_symbols[-1] = bsym + # Bw trace + else: + + def _returned(seq: Sequence) -> tuple: + ret = [] + for e in seq: + if e is None: + ret.append(None) + elif isinstance(e, Sequence): + ret.append(_returned(e)) + elif isinstance(e, Proxy) and e.name in args_remained: + ret.append(e) + elif not isinstance(e, Proxy): + raise RuntimeError(f"type not recognized: {type(e)}") + + return tuple(ret) + + # Backward output is a tuple, and generally a tuple of tuple (()) + original_returned = bound_symbols[-1].args + returned = _returned(original_returned) + bound_symbols[-1] = bound_symbols[-1].from_bsym(args=returned) + + extrace: TraceCtx = from_trace(trace) + extrace.bound_symbols = bound_symbols + return extrace + + +# NOTE: This implementation currently relies on the fact that transformer blocks are contiguous in trace or they have a common gap region between them (in case for bw trace). +# TODO: generalize this +def map_executors_from_reduced_trace_to_complete_trace( + complete_trace: TraceCtx, common_blocks: list[tuple], ex_mappings: list[Executor] +) -> list[Executor]: + from thunder.executors.torchex import ex as torch_ex + + if len(common_blocks) < 2: + raise AssertionError("No common block found") + + # Check that we maintain the pattern + regions_between_blocks = _regions_between_blocks(complete_trace, common_blocks) + + # These are the trace region indices (referred to the complete trace) that we have excluded from the reduced trace optimization. + # We have also to integrate their executors. + # By default torchex will be used as currently no complex (optimizable) ops are present so far (they are usually reshape ops). + indices_excluded: list = _indices_to_exclude_between_common_blocks(common_blocks) + + # Correctness assertion + if regions_between_blocks: + assert len(indices_excluded) % regions_between_blocks == 0 + assert len(indices_excluded) // regions_between_blocks == len(common_blocks) - 1 + + # Solution starting point: copy up to the end of the first common block + complete_trace_executors: list[Executor] = ex_mappings[: common_blocks[0][1] + 1] + # Get the executors sequence to share from the first block to all the other equal blocks. + to_share: list[Executor] = [] + for i in range(len(common_blocks) - 1): + # First region bewteen block, adding here as this was not present in the reduce trace (not found in the ex_mappings structure) + if i == 0: + to_share.extend([torch_ex] * regions_between_blocks) + + to_share.extend(ex_mappings[common_blocks[0][0] : common_blocks[0][1] + 1]) + + # We have to add back the excluded regions (see comment 15 lines above). + if i < len(common_blocks) - 2: + to_share.extend([torch_ex] * regions_between_blocks) + + # Extend by sharing mappings of transformer blocks + complete_trace_executors.extend(to_share) + # Extend with the remained bsyms + complete_trace_executors.extend(ex_mappings[common_blocks[0][1] + 1 :]) + + # Check that we have all the executors needed + len_got = len(complete_trace_executors) + len_expected = len(complete_trace.bound_symbols) + if len_got != len_expected: + raise AssertionError( + f"Trace regions size is different from the obtained executors lenght: {len_expected} - {len_got}" + ) + return complete_trace_executors diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 0a661a1870..7fdf2f9bd8 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -458,12 +458,10 @@ def forward(self, x: torch.Tensor): def test_reorder_executors_list(executors, expected): assert aut_utils.reorder_executors_list(executors) == expected + @pytest.mark.parametrize( "name, expected", - [ - ('linear', [transformer_engine_ex]), - ('scaled_dot_product_attention', [sdpa_ex, cudnn_ex, fa3_ex]) - ], + [("linear", [transformer_engine_ex]), ("scaled_dot_product_attention", [sdpa_ex, cudnn_ex, fa3_ex])], ) def test_get_fw_bw_split_backends_options(name: str, expected): symbol = Symbol(name=name) @@ -471,6 +469,7 @@ def test_get_fw_bw_split_backends_options(name: str, expected): options = get_fw_bw_split_backends_options(bsym) assert all(map(lambda v: v in options, expected)) + class Model_1(torch.nn.Module): def __init__(self, in_f, out_f) -> None: super().__init__() @@ -480,6 +479,7 @@ def forward(self, x): t0 = self.linear(x) return torch.nn.functional.silu(t0) + class Model_2(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -489,12 +489,13 @@ def __init__(self) -> None: def forward(self, x): B, T, C = x.size() - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) return torch.nn.functional.scaled_dot_product_attention(q, k, v) + @pytest.mark.parametrize( "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_cudagraphs", [ @@ -532,7 +533,10 @@ def forward(self, x): torch.float32, "runtime", [sdpa_ex, transformer_engine_ex], - [[sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex]], + [ + [sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex], + [sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex], + ], False, ), ], @@ -548,15 +552,19 @@ def test_autotuner( use_cudagraphs: bool, ): def _run(): - model.to('cuda') - x = torch.randn(tensor_shape, dtype=dtype, device='cuda') + model.to("cuda") + x = torch.randn(tensor_shape, dtype=dtype, device="cuda") jitted_def = thunder.jit(model, executors=executors) - jitted_auto = thunder.jit(model, autotune_type=autotune_type, executors=executors, use_cudagraphs=use_cudagraphs) + jitted_auto = thunder.jit( + model, autotune_type=autotune_type, executors=executors, use_cudagraphs=use_cudagraphs + ) y_def = jitted_def(x) y_auto = jitted_auto(x) te_used = aut_utils.is_te_used(thunder.last_traces(jitted_auto)[-1]) got = thunder.executors_applied(jitted_auto) + print("got", got) + print("expected", expected_executors) assert any([t == got for t in expected_executors]) # With TE enabled deviation ((y_def - y_auto).abs().max().item()) is between tensors are ~0.2 # For the else branch: https://pytorch.org/docs/stable/testing.html @@ -569,39 +577,60 @@ def _run(): _run() +""" +The longest repeated block is: + t2 = x @ y + t3 = t0 + t0 + t4 = t1 * t1 +""" + + +def _test_repetead_transformer_blocks_fn(x: torch.Tensor, y: torch.Tensor): + t0 = x + x + t1 = y * y + t2 = x @ y + t3 = t0 + t0 + t4 = t1 * t1 + t5 = t2 @ t2 + t6 = t3 + t3 + t7 = t4 * t4 + t8 = t6 - t7 + return t8, t5 + + def test_repetead_transformer_blocks(): - device = 'cuda' - # def _fn(x: torch.Tensor, y: torch.Tensor): - # a = x + x - # b = y * y - # c = x @ y - # aa = a + a - # bb = b * b - # return c, aa, bb - - # a = torch.randn(2,2, device = device) - # b = torch.randn(2,2, device = device) - - # jitted = thunder.jit(_fn) - # jitted(a, b) - - # trace = thunder.last_traces(jitted)[-1] - # aut_utils.repetead_transformer_blocks(trace=trace) - - from thunder.tests.litgpt_model import Config - from litgpt import GPT - cfg = Config.from_name('Llama-3-8B') - cfg.n_layer = 2 - - model = GPT(cfg) - model.to(device) - x = torch.randint(1, model.config.vocab_size, (1, cfg.block_size), device=device, ) - jitted = thunder.jit(model, executors=['sdpa']) - jitted(x) + device = "cpu" + + a = torch.randn(2, 2, device=device) + b = torch.randn(2, 2, device=device) + + jitted = thunder.jit(_test_repetead_transformer_blocks_fn, disable_dce=True) + jitted(a, b) + + trace = thunder.last_traces(jitted)[-1] + print(trace) + blocks = aut_utils.repetead_trace_blocks(trace=trace) + assert len(blocks) == 2 + assert blocks[0][1] - blocks[0][0] + 1 == 3 - # trace = thunder.last_traces(jitted)[-1] - # ret = aut_utils.repetead_transformer_blocks(trace=trace) - # print(ret) +def test_reduce_common_trace_blocks(): + device = "cpu" + a = torch.randn(2, 2, device=device) + b = torch.randn(2, 2, device=device) + + jitted = thunder.jit(_test_repetead_transformer_blocks_fn, disable_dce=True) + jitted(a, b) + + trace = thunder.last_traces(jitted)[-1] + blocks = aut_utils.repetead_trace_blocks(trace=trace) + reduced_trace = aut_utils.reduce_common_trace_blocks( + trace=trace, common_blocks_in=blocks, skip_between_blocks=False + ) + # We expect that t5, t6, t7 have been removed + should_remove = set(["t5", "t6", "t7"]) + for b in reduced_trace.bound_symbols: + if hasattr(b.output, "name"): + assert b.output.name not in should_remove From cbc4bb647dbaf9fb0901f1471c2295c1c75e67d7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 30 Aug 2024 19:48:05 +0300 Subject: [PATCH 122/171] Fixed bad def value --- thunder/backend_optimizer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index e87493823a..21c20007b1 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -836,7 +836,7 @@ def _hash(bsym: BoundSymbol) -> str: # Both lhs and rhs are included in the range # TODO: known_points can be used to detect start and end of a block sequence def repetead_trace_blocks( - *, trace: TraceCtx, min_block_size=1, known_points: tuple[BoundSymbol, BoundSymbol] | None = None + *, trace: TraceCtx, min_block_size=2, known_points: tuple[BoundSymbol, BoundSymbol] | None = None ) -> list[tuple]: if min_block_size < 2: return [] From a14a155e4ca8df8d7fb50735b2207ecda2ca0db6 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 30 Aug 2024 22:49:51 +0300 Subject: [PATCH 123/171] Fixed comment --- thunder/backend_optimizer/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 21c20007b1..69280b387c 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -951,11 +951,11 @@ def _range_seen(index: int, s: set): if max_lcs < min_block_size: return [] - # print(f'\n\nMax lcs {max_lcs}') + print(f'\n\nMax lcs {max_lcs}') # print(res) - # for r in res: - # print(symbols[r[0]].output.name, symbols[r[1]].output.name) + for r in res: + print(symbols[r[0]].output.name, symbols[r[1]].output.name) return [ (original_map_indexes[symbols[t[0]].output.name], original_map_indexes[symbols[t[1]].output.name]) for t in res ] @@ -1035,7 +1035,7 @@ def _find_bsym_index(out_name: str, space: Sequence[BoundSymbol]) -> int: trc.bound_symbols = list(trace.bound_symbols) return trc - # Create a mapping where we can easily find to which block a specific output blongs + # Create a mapping where we can easily find to which block a specific output belongs output_to_block: dict[str, tuple[int, int]] = {} for n_block, block in enumerate(common_blocks): for i in range(block[0], block[1] + 1): From 5e0c379f0490ab7b481cd1ef6e50c73cbfbe8795 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 30 Aug 2024 22:50:42 +0300 Subject: [PATCH 124/171] Fixed comment --- thunder/backend_optimizer/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 69280b387c..e7f5c9b3bc 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1054,7 +1054,7 @@ def _find_bsym_index(out_name: str, space: Sequence[BoundSymbol]) -> int: # Make it fast to search in index_gaps_to_exclude = set(index_gaps_to_exclude) - # Create reduce trace regions + # Create reduced trace regions bound_symbols: list[BoundSymbol] = [ b for i, b in enumerate(trace.bound_symbols) if not _exclude(common_blocks[1:], i, index_gaps_to_exclude) ] From dc9d181ae2eb345ca7703e7c4d2b8435e9fc5b3d Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 30 Aug 2024 23:00:13 +0300 Subject: [PATCH 125/171] Changed log level --- thunder/backend_optimizer/optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index c8cf4a3a9f..f52c0cb37b 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -893,7 +893,7 @@ def _optimize(): ) # print(common_trace_blocks) if len(common_trace_blocks) >= 2 and optimize_common_blocks: - log(f"Common blocks found {common_trace_blocks}", level=LogLevel.DEBUG) + log(f"Common blocks found {common_trace_blocks}", level=LogLevel.INFO) reduced_trace = reduce_common_trace_blocks(trace=self.trace, common_blocks_in=common_trace_blocks) log( f"Operating on reduced trace (by cutting common transformer blocks):\n{reduced_trace}", From b80ffbef83fd7547c4db2c498c56c2231036b4a7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 30 Aug 2024 23:00:42 +0300 Subject: [PATCH 126/171] Formatted comment --- thunder/backend_optimizer/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index e7f5c9b3bc..627ab4c08c 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1157,7 +1157,8 @@ def _returned(seq: Sequence) -> tuple: return extrace -# NOTE: This implementation currently relies on the fact that transformer blocks are contiguous in trace or they have a common gap region between them (in case for bw trace). +# NOTE: This implementation currently relies on the fact that transformer blocks are contiguous in trace +# or they have a common gap region between them (in case for bw trace). # TODO: generalize this def map_executors_from_reduced_trace_to_complete_trace( complete_trace: TraceCtx, common_blocks: list[tuple], ex_mappings: list[Executor] From 62b75efaa71ed3cab1c3bf2e03e09bd87508d797 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 30 Aug 2024 23:06:12 +0300 Subject: [PATCH 127/171] Updated runner --- examples/dev/litGPT.py | 49 +++++------------------------------------- 1 file changed, 5 insertions(+), 44 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index fea025192c..50bc2fb6c5 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -32,7 +32,7 @@ def __init__( self.seq_len = seq_len self.model_name = model_name self.executors = executors - self.optimize_transformer_blocks = (optimize_transformer_blocks,) + self.optimize_transformer_blocks = optimize_transformer_blocks self.optimize_transformer_min_block_size = optimize_transformer_min_block_size @@ -44,7 +44,7 @@ def __init__( executors=[ "cudnn", "sdpa", - "fa3", + # "fa3", "nvfuser", "torchcompile", ], @@ -56,50 +56,11 @@ def __init__( executors=[ "cudnn", "sdpa", - "fa3", + # "fa3", "nvfuser", "torchcompile", ], ), - # Test( - # 3, - # "runtime", - # 1, - # executors=[ - # "cudnn", - # "sdpa", - # # "fa3", - # "nvfuser", - # "torchcompile", - # ], - # seq_len=128 - # ), - # Test( - # 1, - # "memory", - # 1, - # executors=[ - # "cudnn", - # "sdpa", - # "fa3", - # "nvfuser", - # "torchcompile", - # ], - # ), - # Test( - # 1, - # "runtime", - # 1, - # executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], - # model_name="stablecode-completion-alpha-3b", - # ), - # Test( - # 1, - # "memory", - # 1, - # executors=["cudnn", "sdpa", "fa3", "nvfuser", "torchcompile"], - # model_name="stablecode-completion-alpha-3b", - # ), ] for test in layers: @@ -108,7 +69,7 @@ def __init__( cfg.n_layer = test.layers if test.seq_len != -1: cfg.block_size = test.seq_len - torch.set_default_dtype(torch.float32) + torch.set_default_dtype(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) print(cfg) with torch.device("cuda"): model = GPT(cfg) @@ -132,7 +93,7 @@ def __init__( e = time.time_ns() print("Compilation time:", {(e - s) / 1000000000}, "s") - iters = 40 + iters = 100 fw_traces = [ thunder.last_traces(jmodel_def)[-1], thunder.last_traces(jmodel_auto)[-1], From a9a9a3f1607f569ea7eb0c3f5682489a54e325e7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 31 Aug 2024 00:53:25 +0300 Subject: [PATCH 128/171] Enabled te and nvFuser compile options from thunder jit / updated tests --- examples/dev/LLaMAMLP.py | 7 +- examples/dev/MLP.py | 79 ------------ examples/dev/conv2d_relu.py | 31 ----- examples/dev/litGPT.py | 9 +- examples/dev/nanogpt-block.py | 159 ------------------------- examples/dev/nanogpt.py | 4 + examples/dev/nvfuser_optimizations.py | 9 +- examples/dev/nvmath_example.py | 0 examples/dev/sdpa.py | 11 +- examples/dev/sdpa_linear.py | 77 ------------ examples/dev/simple.py | 55 --------- examples/dev/te.py | 9 +- examples/dev/test_del.py | 32 ----- thunder/__init__.py | 23 +++- thunder/backend_optimizer/optimizer.py | 55 +++++---- thunder/backend_optimizer/utils.py | 8 +- thunder/core/vjp_utils.py | 6 +- thunder/tests/test_autotuner.py | 29 +++-- 18 files changed, 114 insertions(+), 489 deletions(-) delete mode 100644 examples/dev/MLP.py delete mode 100644 examples/dev/conv2d_relu.py delete mode 100644 examples/dev/nanogpt-block.py delete mode 100644 examples/dev/nvmath_example.py delete mode 100644 examples/dev/sdpa_linear.py delete mode 100644 examples/dev/simple.py delete mode 100644 examples/dev/test_del.py diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index 708b529ab5..f69a687e9b 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -1,3 +1,7 @@ +""" +This benchmark script is intended to demonstrate the optimizer on a generic model. +No executor are given leaving full responsibility to the engine. +""" import torch import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_total_benchmark @@ -29,8 +33,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: jmodel_auto = thunder.jit( model, autotune_type="runtime", - executors=["nvfuser", "torchcompile", "transformer_engine"], - use_cudagraphs=False, + autotune_enable_te=True ) print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) diff --git a/examples/dev/MLP.py b/examples/dev/MLP.py deleted file mode 100644 index 0f249ee200..0000000000 --- a/examples/dev/MLP.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import torch.nn as nn -import thunder -from thunder.benchmarks.utils import ( - thunder_fw_bw_benchmark, - torch_fw_bw_benchmark, - torch_fw_bw_benchmark_nvsight, - torch_total_benchmark, -) - - -class ModelConfig: - def __init__(self, n_embd=256, n_head=8, dropout=0.1, block_size=64, bias=True): - self.n_embd = n_embd - self.n_head = n_head - self.dropout = dropout - self.bias = bias - self.block_size = block_size - - -class MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.gelu = nn.GELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x): - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - x = self.dropout(x) - return x - - -with torch.device("cuda"): - embeddings = 3072 - config = ModelConfig(n_embd=embeddings, dropout=0.0, bias=False) - dtype = torch.float32 - x = torch.randn(16, 1024, embeddings, requires_grad=True) - - model = MLP(config) - - jmodel_def = thunder.jit(model) - # This model fails under some circumstances after passed the placed traced under the rematelizer - jmodel_auto = thunder.jit( - model, - autotune_type="runtime", - executors=["nvfuser", "torchcompile", "sdpa", "torch", "python"], - use_cudagraphs=False, - ) - - print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) - print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) - - iters = 100 - callables = [jmodel_auto, jmodel_def] - labels = ["auto", "def"] - inputs = [x, x] - print("Results with torch total benchmark:") - torch_total_benchmark(callables, labels, inputs, iters) - - print("Results with thunder benchmark:") - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - - # for t in traces: - # print(t) - # print('##########################') diff --git a/examples/dev/conv2d_relu.py b/examples/dev/conv2d_relu.py deleted file mode 100644 index c7b9386a26..0000000000 --- a/examples/dev/conv2d_relu.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import thunder - - -class Module(torch.nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride=1) -> None: - super().__init__() - self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride) - self.relu = torch.nn.ReLU() - - def forward(self, x: torch.Tensor): - a = self.conv2d(x) - b = self.conv2d(x) - c = self.conv2d(x + x) - d = self.relu(b * a) - return c + d - - -with torch.device("cuda"): - model = Module(16, 33, 3, stride=2) - x = torch.randn(20, 16, 50, 100) - - jmodel = thunder.jit(model) - - ans = jmodel(x) - # print('---------------------------------------------- all traces') - # for t in thunder.last_traces(jmodel): - # print(t) - # print('##############################################') - # print('---------------------------------------------- ans') - # print(ans) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 50bc2fb6c5..fc089bd98d 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -1,3 +1,7 @@ +""" +This script benchmarks litGPT models in a easier wrt to benchmark_litgpt.py way with a fake training loop with no optimizers in order to focus more on +forward and backward computation time and not others kernel during the loop. +""" from litgpt import GPT from thunder.benchmarks.utils import ( thunder_fw_bw_benchmark, @@ -83,9 +87,8 @@ def __init__( model, autotune_type=test.autotune_type, executors=test.executors, - use_cudagraphs=False, - optimize_common_blocks=test.optimize_transformer_blocks, - optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, + autotune_optimize_common_blocks=test.optimize_transformer_blocks, + autotune_optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, ) print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) s = time.time_ns() diff --git a/examples/dev/nanogpt-block.py b/examples/dev/nanogpt-block.py deleted file mode 100644 index f1f1693e22..0000000000 --- a/examples/dev/nanogpt-block.py +++ /dev/null @@ -1,159 +0,0 @@ -import math -import torch -import torch.nn as nn -from torch.nn import functional as F -import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark - -# torch.set_default_dtype(torch.bfloat16) - - -class LayerNorm(nn.Module): - """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" - - def __init__(self, ndim, bias): - super().__init__() - self.weight = nn.Parameter(torch.ones(ndim)) - self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None - - def forward(self, input): - return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) - - -class CausalSelfAttention(nn.Module): - def __init__(self, config): - super().__init__() - assert config.n_embd % config.n_head == 0 - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) - # output projection - self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) - # regularization - self.attn_dropout = nn.Dropout(config.dropout) - self.resid_dropout = nn.Dropout(config.dropout) - self.n_head = config.n_head - self.n_embd = config.n_embd - self.dropout = config.dropout - # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 - self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") - if not self.flash: - print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0") - # causal mask to ensure that attention is only applied to the left in the input sequence - self.register_buffer( - "bias", - torch.tril(torch.ones(config.block_size, config.block_size)).view( - 1, 1, config.block_size, config.block_size - ), - ) - - def forward(self, x): - B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - q, k, v = self.c_attn(x).split(self.n_embd, dim=2) - k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - if self.flash: - # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention( - q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True - ) - else: - # manual implementation of attention - att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) - att = F.softmax(att, dim=-1) - att = self.attn_dropout(att) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side - - # output projection - y = self.resid_dropout(self.c_proj(y)) - return y - - -class MLP(nn.Module): - def __init__(self, config): - super().__init__() - self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) - self.gelu = nn.GELU() - self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) - self.dropout = nn.Dropout(config.dropout) - - def forward(self, x): - x = self.c_fc(x) - x = self.gelu(x) - x = self.c_proj(x) - x = self.dropout(x) - return x - - -class GPTConfig: - block_size: int = 1024 - vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency - n_layer: int = 12 - n_head: int = 12 - n_embd: int = 3072 - dropout: float = 0.0 - bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster - - -class Block(nn.Module): - def __init__(self, config): - super().__init__() - self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) - self.attn = CausalSelfAttention(config) - self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) - self.mlp = MLP(config) - - def forward(self, x): - x = x + self.attn(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -with torch.device("cuda"): - config = GPTConfig() - model = Block(config) - x = torch.randn((16, 1024, 3072), dtype=torch.float32) - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit( - model, autotune_type="runtime", executors=["nvfuser", "torchcompile", "sdpa", "cudnn", "torch", "python"] - ) - - print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) - print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) - - iters = 100 - print("Results thunder benchmark:") - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - - print("\n\nResults torch fw bw benchmark:") - callables = [jmodel_def, jmodel_auto] - labels = ["def", "auto"] - inputs = [x, x] - torch_fw_bw_benchmark(callables, labels, inputs, 100) - - print("\n\n\n\n\n\n") - print(f"{thunder.last_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_traces(jmodel_auto)[-1]}") - - print("\n\n") - print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index d70a452f85..f0b1331e49 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -1,3 +1,7 @@ +""" +This benchmark script is intended to demonstrate the optimizer on nanoGPT model. +The script runner is taken from: https://github.com/karpathy/nanoGPT/blob/master/bench.py +""" import torch import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark diff --git a/examples/dev/nvfuser_optimizations.py b/examples/dev/nvfuser_optimizations.py index 720eaa7a0a..7055a04df3 100644 --- a/examples/dev/nvfuser_optimizations.py +++ b/examples/dev/nvfuser_optimizations.py @@ -1,3 +1,8 @@ +""" +This benchmark script is intended to demonstrate the optimizer optimizing the nvFuser executor with its compile options. + +nvFuser compile options can be autotune with the argument `autotune_enable_nvfuser_all=True`. +""" import torch import thunder from thunder.benchmarks.utils import ( @@ -34,7 +39,9 @@ def forward(self, x: torch.Tensor): x = torch.randn(1 << 9, in_features, requires_grad=True) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type="runtime", executors=["nvfuser", "cudnn", "torch", "python"]) + jmodel_auto = thunder.jit( + model, autotune_type="runtime", executors=["nvfuser", "cudnn"], autotune_enable_nvfuser_all=True + ) y = jmodel_def(x) y = jmodel_auto(x) diff --git a/examples/dev/nvmath_example.py b/examples/dev/nvmath_example.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index e8f666c806..6aeaed4d13 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -1,11 +1,15 @@ +""" +This benchmark script is intended to demonstrate the optimizer working on +the single trace region bext executor (when the forward trace symbol will influence the backward trace). + +Set the log level at least to INF0 in `thunder/backend_optimizer/optimizer.py`. +""" import torch import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_total_benchmark dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 torch.set_default_dtype(dtype) -print(f"Script data type: {dtype}\n") - class Model(torch.nn.Module): def __init__(self) -> None: @@ -17,13 +21,12 @@ def forward(self, query, key, value): b = torch.nn.functional.scaled_dot_product_attention(query + query, key + key, value + value) return a + b - with torch.device("cuda"): model = Model() jmodel_def = thunder.jit(model) jmodel_auto = thunder.jit( - model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa"], use_cudagraphs=False + model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa"] ) q = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) diff --git a/examples/dev/sdpa_linear.py b/examples/dev/sdpa_linear.py deleted file mode 100644 index 0daa9498f1..0000000000 --- a/examples/dev/sdpa_linear.py +++ /dev/null @@ -1,77 +0,0 @@ -import torch -import thunder -from thunder.backend_optimizer.optimizer import benchmark_trace - -torch.set_default_dtype(torch.float32) - - -class Model(torch.nn.Module): - def __init__(self, inf, outf) -> None: - super().__init__() - self.linear = torch.nn.Linear(inf, outf, bias=False) - - def forward(self, query, key, value): - query = self.linear(query) - a = torch.nn.functional.scaled_dot_product_attention(query, key, value) - return a - - -with torch.device("cuda"): - features = 128 - model = Model(features, features) - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit( - model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa", "fa3", "torchcompile"] - ) - - q = torch.rand(32, 8, 128, features, requires_grad=True) - k = torch.rand(32, 8, 128, features, requires_grad=True) - v = torch.rand(32, 8, 128, features, requires_grad=True) - - print("deviation def:", (jmodel_def(q, k, v) - model(q, k, v)).abs().max().item()) - print("deviation auto:", (jmodel_auto(q, k, v) - model(q, k, v)).abs().max().item()) - - print("########################################") - c, m, o = benchmark_trace( - thunder.last_traces(jmodel_def)[-1], - apply_del_last_used=False, - snapshot=True, - snapshot_name="sdpa_def_fw", - iters=10, - ) - print(f"Executing default fw trace:\n{c} ms, {m / (2**30)} GB") - c, m, o = benchmark_trace( - thunder.last_traces(jmodel_auto)[-1], - apply_del_last_used=False, - snapshot=True, - snapshot_name="sdpa_auto_fw", - iters=10, - ) - print(f"Executing auto fw trace:\n{c} ms, {m / (2**30)} GB") - c, m, o = benchmark_trace( - thunder.last_backward_traces(jmodel_def)[-1], - apply_del_last_used=False, - snapshot=True, - snapshot_name="sdpa_def_bw", - iters=10, - ) - print(f"Executing default bw trace:\n{c} ms, {m / (2**30)} GB") - c, m, o = benchmark_trace( - thunder.last_backward_traces(jmodel_auto)[-1], - apply_del_last_used=False, - snapshot=True, - snapshot_name="sdpa_auto_bw", - iters=10, - ) - print(f"Executing auto bw trace:\n{c} ms, {m / (2**30)} GB") - - print("\n\n\n\n\n\n") - print(f"{thunder.last_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_traces(jmodel_auto)[-1]}") - - print("\n\n") - print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") diff --git a/examples/dev/simple.py b/examples/dev/simple.py deleted file mode 100644 index bc88314cce..0000000000 --- a/examples/dev/simple.py +++ /dev/null @@ -1,55 +0,0 @@ -import torch -import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark - - -class Module(torch.nn.Module): - def __init__(self, in_features, out_features) -> None: - super().__init__() - self.linear = torch.nn.Linear(in_features, out_features) - self.silu = torch.nn.SiLU() - - def forward(self, x: torch.Tensor): - a = x + x - b: torch.Tensor = self.linear(a) - c = b * b - return self.silu(c) - - -with torch.device("cuda"): - in_features = 4096 - out_features = 11008 - model = Module(in_features, out_features) - x = torch.randn(128, in_features, requires_grad=True) - - jmodel_def = thunder.jit( - model, - ) - jmodel_auto = thunder.jit( - model, - autotune_type="runtime", - executors=["nvfuser", "torchcompile", "cudnn", "torch", "python"], - ) - - y = jmodel_def(x) - y = jmodel_auto(x) - - iters = 100 - print("Results thunder benchmark:") - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - - callables = [jmodel_def, jmodel_auto] - labels = ["def", "auto"] - inputs = [x, x] - print("Results torch benchmark:") - torch_fw_bw_benchmark(callables, labels, inputs, 50) diff --git a/examples/dev/te.py b/examples/dev/te.py index 6341506454..6f8d08d691 100644 --- a/examples/dev/te.py +++ b/examples/dev/te.py @@ -1,3 +1,8 @@ +""" +This benchmark script is intended to demonstrate the optimizer supporting the transformer engine executor. + +This option can be enabled inside the autotuner by using the flag `autotune_enable_te=True`. +""" import torch import thunder from thunder.benchmarks.utils import ( @@ -28,7 +33,7 @@ def forward(self, x: torch.Tensor): model = Module(in_features, out_features) x = torch.randn(768, in_features, requires_grad=True) - jmodel_def = thunder.jit(model, executors=["transformer_engine"], use_cudagraphs=False) + jmodel_def = thunder.jit(model) jmodel_auto = thunder.jit( model, autotune_type="runtime", @@ -36,7 +41,7 @@ def forward(self, x: torch.Tensor): "nvfuser", "transformer_engine", ], - use_cudagraphs=False, + autotune_enable_te=True ) y = jmodel_def(x) diff --git a/examples/dev/test_del.py b/examples/dev/test_del.py deleted file mode 100644 index 50654edbc3..0000000000 --- a/examples/dev/test_del.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -import time - -iters = 1000 - -with torch.device("cuda"): - tot_time = 0 - for i in range(iters): - s = time.time_ns() - a = torch.randn(2, 2048, 4096 // 1, requires_grad=True) - b = torch.randn(2, 2048, 4096 // 1, requires_grad=True) - c = a + b + a + b - c = c * c - del a - del b - del c - torch.cuda.synchronize() - tot_time += time.time_ns() - s - - print(f"With del = {(tot_time / iters) / 1000000}") - - tot_time = 0 - for i in range(iters): - s = time.time_ns() - a = torch.randn(2, 2048, 4096 // 1, requires_grad=True) - b = torch.randn(2, 2048, 4096 // 1, requires_grad=True) - c = a + b + a + b - c = c * c - torch.cuda.synchronize() - tot_time += time.time_ns() - s - - print(f"With no del = {(tot_time / iters) / 1000000}") diff --git a/thunder/__init__.py b/thunder/__init__.py index 26ed89372e..037a01d334 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -333,19 +333,34 @@ def jit( compile_options |= { "autotune_type": OptimizerType.RUNTIME if required_autotune == "runtime" else OptimizerType.MEMORY, - "executors_placed_by_fw_bw_split": set(), + "autotune_executors_placed_by_fw_bw_split": set(), } # Default the executors list to all_executors if no options are given # Otherwise the user restricted choice will be used + from thunder.executors.transformer_engineex import transformer_engine_ex + from thunder.executors.cudagraphex import cudagraphex + from thunder.executors.pythonex import ex as python_ex if not executors: executors = get_all_executors() # Remove python and cudagraph - executors = [ex for ex in executors if ex.name != "python" and ex.name != "cudagraphex"] + executors = [ex for ex in executors if ex != python_ex and ex != cudagraphex] + # Remove transformer_engine if not requested + executors = [ + ex + for ex in executors + if ex != transformer_engine_ex + or (ex == transformer_engine_ex and compile_options.get("autotune_enable_te", False)) + ] + else: + # If TE is in executors list we have to enable the compilation option + if transformer_engine_ex in executors: + compile_options['autotune_enable_te'] = True from thunder.backend_optimizer.utils import reorder_executors_list - - executors = reorder_executors_list(executors) + executors = reorder_executors_list( + executors, autotune_enable_te=compile_options.get("autotune_enable_te", False) + ) # Resolve names of executors executors = resolve_executors(executors) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index f52c0cb37b..fa882d826a 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -16,18 +16,22 @@ from thunder.backend_optimizer.utils import benchmark_trace -def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None) -> list | dict: +# This fn is used before compile data being set, rely on kwargs +def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None, **kwargs) -> list | dict: from thunder.executors.sdpaex import sdpa_ex from thunder.executors.cudnnex import cudnn_ex from thunder.executors.fa3ex import fa3_ex from thunder.executors.transformer_engineex import transformer_engine_ex - # Current configuration - options: dict[str, list] = { - # TODO: filter out TE only if requested - "linear": [transformer_engine_ex], - "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], - } + if kwargs is None or not kwargs.get("autotune_enable_te", False): + options: dict[str, list] = { + "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], + } + else: + options: dict[str, list] = { + "linear": [transformer_engine_ex], + "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], + } return options.get(bsym.sym.name, []) if bsym else options @@ -239,17 +243,25 @@ def __init__( self.fusion_strat_helper: FusionStratHelper = FusionStratHelper() self.executor_placement_options: ExecutorPlacementOptions = ExecutorPlacementOptions() - from thunder.executors.nvfuserex_impl import linear, _linear_check - from thunder.executors.nvfuserex_impl import matmul, _matmul_check - - self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { - "nvfuser": [ - # FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), - # FusionCompileOptionsHelper("nv_enable_matmul", "matmul", PrimIDs.MATMUL, matmul, _matmul_check), - # FusionCompileOptionsHelper("nv_enable_bookend", "bookend"), - ] - } - + # nvFuser compile options + if compile_data.compile_options.get('autotune_enable_nvfuser_all', False): + from thunder.executors.nvfuserex_impl import linear, _linear_check + from thunder.executors.nvfuserex_impl import matmul, _matmul_check + + self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { + "nvfuser": [ + FusionCompileOptionsHelper("nv_enable_linear", "linear", PrimIDs.LINEAR, linear, _linear_check), + FusionCompileOptionsHelper("nv_enable_matmul", "matmul", PrimIDs.MATMUL, matmul, _matmul_check), + ] + } + else: + self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { + "nvfuser": [ + ] + } + + # Transformer based models optimization + # TODO: explain self.is_reduced: bool = False self.cached_original_trace: TraceCtx | None = None @@ -278,8 +290,7 @@ def _best_runtime_and_memory_candidates(self, candidates): (pair.fw, pair.bw), (remat_fw, remat_bw), ] - # We want to verify that it is not set to false - if self.compile_data.use_cudagraphs is None or self.compile_data.use_cudagraphs == True: + if self.compile_data.use_cudagraphs is not None and self.compile_data.use_cudagraphs: from thunder.executors.cudagraphex import cudagraphex pair_options.extend( @@ -882,9 +893,9 @@ def _optimize(): cd = get_compile_data() # Check if common blocks optimization is requested - optimize_common_blocks = False if cd is None else cd.compile_options.get("optimize_common_blocks", False) + optimize_common_blocks = False if cd is None else cd.compile_options.get("autotune_optimize_common_blocks", False) optimize_common_blocks_min_size = ( - -1 if cd is None else cd.compile_options.get("optimize_common_blocks_min_size", -1) + -1 if cd is None else cd.compile_options.get("autotune_optimize_common_blocks_min_size", -1) ) # Cut the compilation time if possible diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 627ab4c08c..53dc937db3 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -635,7 +635,7 @@ def update_compile_options_executor_list_after_fw_bw_split() -> None: assert cd # Get all the possible options that the vjp_optimization pass will use - options: dict = get_fw_bw_split_backends_options() + options: dict = get_fw_bw_split_backends_options(autotune_enable_te=cd.compile_options.get('autotune_enable_te', False)) executors_list = list(cd.executors_list) # Remove all the initial options @@ -645,7 +645,7 @@ def update_compile_options_executor_list_after_fw_bw_split() -> None: executors_list.remove(ex) # Putting at the front even though order does not matter - for ex in cd.compile_options["executors_placed_by_fw_bw_split"]: + for ex in cd.compile_options["autotune_executors_placed_by_fw_bw_split"]: executors_list.insert(0, ex) # Assign new compilation executors options @@ -741,13 +741,13 @@ def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | l return tuple(res) if level > 0 else res -def reorder_executors_list(executors: Sequence): +def reorder_executors_list(executors: Sequence, **kwargs): from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options from thunder.executors.torch_compile import torch_compile_ex from thunder.executors.nvfuserex_impl import ex as nvfuser_ex reordered = [] - options = get_fw_bw_split_backends_options() + options = get_fw_bw_split_backends_options(**kwargs) are_inputs_names = isinstance(executors[0], str) diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index a956a54b65..d9f4bb5154 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -232,7 +232,9 @@ def bw_fn(*args, **kwargs): return cached_result # Get the possible backends for the current bsym - backends = get_fw_bw_split_backends_options(bsym) + backends = get_fw_bw_split_backends_options( + bsym, autotune_enable_te=cd.compile_options.get("autotune_enable_te", False) + ) if not backends: raise AssertionError( f"No enabled backends found for {bsym.sym.name} but an executor for that symbol it is present in the executors list. Either remove that from the executors list or enable at least one backend for {bsym.sym.name} inside 'get_fw_bw_split_backends_options'." @@ -278,7 +280,7 @@ def bw_fn(*args, **kwargs): ) # Update the compile options - cd.compile_options["executors_placed_by_fw_bw_split"].add(best.executor) + cd.compile_options["autotune_executors_placed_by_fw_bw_split"].add(best.executor) from thunder.executors.transformer_engineex import transformer_engine_ex cd.compile_options |= {"te_used": True if best.executor == transformer_engine_ex else False} diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 7fdf2f9bd8..041da26920 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -445,18 +445,18 @@ def forward(self, x: torch.Tensor): @pytest.mark.parametrize( - "executors, expected", + "executors, expected, use_te", [ - (["python"], ["nvfuser", "python"]), - (["nvfuser", "cudnn"], ["cudnn", "nvfuser"]), - (["torch", "nvfuser", "sdpa"], ["sdpa", "torch", "nvfuser"]), - (["transformer_engine", "nvfuser", "sdpa"], ["transformer_engine", "sdpa", "nvfuser"]), + (["python"], ["nvfuser", "python"], False), + (["nvfuser", "cudnn"], ["cudnn", "nvfuser"], False), + (["torch", "nvfuser", "sdpa"], ["sdpa", "torch", "nvfuser"], False), + (["transformer_engine", "nvfuser", "sdpa"], ["transformer_engine", "sdpa", "nvfuser"], True), ], ) # We might not have nvfuser in non cuda envs @requiresCUDA -def test_reorder_executors_list(executors, expected): - assert aut_utils.reorder_executors_list(executors) == expected +def test_reorder_executors_list(executors, expected, use_te): + assert aut_utils.reorder_executors_list(executors, autotune_enable_te=use_te) == expected @pytest.mark.parametrize( @@ -466,7 +466,7 @@ def test_reorder_executors_list(executors, expected): def test_get_fw_bw_split_backends_options(name: str, expected): symbol = Symbol(name=name) bsym = BoundSymbol(symbol, (), {}, None) - options = get_fw_bw_split_backends_options(bsym) + options = get_fw_bw_split_backends_options(bsym, autotune_enable_te=True) assert all(map(lambda v: v in options, expected)) @@ -497,9 +497,9 @@ def forward(self, x): @pytest.mark.parametrize( - "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_cudagraphs", + "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_cudagraphs, use_te", [ - (Model_1(32, 32), (32, 32), torch.float32, "runtime", [nvfuserex], [[nvfuserex, torchex, pythonex]], True), + (Model_1(32, 32), (32, 32), torch.float32, "runtime", [nvfuserex], [[nvfuserex, torchex, pythonex]], True, False), ( Model_1(32, 32), (32, 32), @@ -508,6 +508,7 @@ def forward(self, x): [torch_compile_ex], [[torch_compile_ex, torchex, pythonex]], True, + False ), ( Model_1(4096, 4096), @@ -517,6 +518,7 @@ def forward(self, x): [transformer_engine_ex], [[transformer_engine_ex, nvfuserex, torchex, pythonex]], False, + True ), ( Model_2(), @@ -526,6 +528,7 @@ def forward(self, x): [sdpa_ex, cudnn_ex], [[sdpa_ex, nvfuserex, torchex, pythonex], [cudnn_ex, nvfuserex, torchex, pythonex]], False, + False, ), ( Model_2(), @@ -535,9 +538,10 @@ def forward(self, x): [sdpa_ex, transformer_engine_ex], [ [sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex], - [sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex], + [transformer_engine_ex, sdpa_ex, nvfuserex, torchex, pythonex], ], False, + True ), ], ) @@ -550,13 +554,14 @@ def test_autotuner( executors: list, expected_executors: list[list], use_cudagraphs: bool, + use_te: bool ): def _run(): model.to("cuda") x = torch.randn(tensor_shape, dtype=dtype, device="cuda") jitted_def = thunder.jit(model, executors=executors) jitted_auto = thunder.jit( - model, autotune_type=autotune_type, executors=executors, use_cudagraphs=use_cudagraphs + model, autotune_type=autotune_type, executors=executors, use_cudagraphs=use_cudagraphs, autotune_enable_te=use_te ) y_def = jitted_def(x) y_auto = jitted_auto(x) From 448fd8ac5b65ccd27c3adeb85cd234a7cd82f29f Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 31 Aug 2024 14:02:48 +0300 Subject: [PATCH 129/171] Disabled cudagraphs --- examples/dev/litGPT.py | 1 + thunder/backend_optimizer/optimizer.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index fc089bd98d..a20209e136 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -87,6 +87,7 @@ def __init__( model, autotune_type=test.autotune_type, executors=test.executors, + use_cudagraphs=False, autotune_optimize_common_blocks=test.optimize_transformer_blocks, autotune_optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, ) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index fa882d826a..6a3f66cb4e 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -290,15 +290,15 @@ def _best_runtime_and_memory_candidates(self, candidates): (pair.fw, pair.bw), (remat_fw, remat_bw), ] - if self.compile_data.use_cudagraphs is not None and self.compile_data.use_cudagraphs: - from thunder.executors.cudagraphex import cudagraphex - - pair_options.extend( - [ - (cudagraphex.fusion_pass(pair.fw), cudagraphex.fusion_pass(pair.bw)), - (cudagraphex.fusion_pass(remat_fw), cudagraphex.fusion_pass(remat_bw)), - ] - ) + # if self.compile_data.use_cudagraphs is not None and self.compile_data.use_cudagraphs: + # from thunder.executors.cudagraphex import cudagraphex + + # pair_options.extend( + # [ + # (cudagraphex.fusion_pass(pair.fw), cudagraphex.fusion_pass(pair.bw)), + # (cudagraphex.fusion_pass(remat_fw), cudagraphex.fusion_pass(remat_bw)), + # ] + # ) # Select the best options for pair_option in pair_options: fw = pair_option[0] From b08d9e1a7897bed69da60c3abcb2479bc7c5835a Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 31 Aug 2024 14:04:50 +0300 Subject: [PATCH 130/171] Restricting the same executor in vjp pass if common trace block opt is True --- thunder/backend_optimizer/utils.py | 2 - thunder/core/vjp_utils.py | 61 ++++++++++++++++++------------ 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 53dc937db3..6f761b51cc 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -893,7 +893,6 @@ def _skip(bsym: BoundSymbol) -> bool: for i, bsym in enumerate(symbols): if i == len(symbols) - 1: break - # Skip None outputs (unpacks, returns, del) if _skip(bsym): continue h = symbol_hash(bsym) @@ -915,7 +914,6 @@ def _range_seen(index: int, s: set): for i, bsym in enumerate(symbols): if i == len(symbols) - 1: break - # Skip None outputs (unpacks, returns, del) if _skip(bsym): continue diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index d9f4bb5154..ee90b90250 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -3,8 +3,8 @@ from functools import wraps from inspect import Parameter, Signature from itertools import chain -from os import execl +from thunder.backend_optimizer.utils import symbol_hash from thunder.core import prims, utils from thunder.core.compile_data import get_compile_data from thunder.core.prims import PrimIDs @@ -17,6 +17,7 @@ _cache = {} +_autototune_common_bsym_in_blocks_cache = {} def disable_caching_split_forward_and_backward(fn): @@ -222,6 +223,7 @@ def bw_fn(*args, **kwargs): else: from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options from thunder.backend_optimizer.utils import benchmark_trace + from thunder.backend_optimizer.optimizer import log, LogLevel # In order define this unique trace region we need an unique id key = (bsym.sym, Executor(f"{id(bsym)}-autotuned"), subkey := _make_cache_key(bsym.args, bsym.kwargs)) @@ -248,37 +250,48 @@ def bw_fn(*args, **kwargs): best = SplitFwBwBenchmarkUtils() - # Restrict the search space - backends = list(requested_executors_list_for_bsym) - - from thunder.backend_optimizer.optimizer import log, LogLevel - - log(f"Search space for {bsym.sym.name}: {backends}", level=LogLevel.INFO) - for b in backends: - log(f"Benchmarking executor {b.name} for {bsym.sym.name}", level=LogLevel.INFO) - # Let downstream fn to pick up this - requested_executors_list_for_bsym.remove(b) - requested_executors_list_for_bsym.insert(0, b) - cd.executors_list = requested_executors_list_for_bsym - fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(return_traces=True, update_cache=False) - # What should be the optimal iter? - # TODO: make benchmark info taken from an autotuner config - fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=100, apply_del_last_used=False) - bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=100, apply_del_last_used=False, fw_trace=fw_trace) - cost = ( - fw_time + bw_time if cd.compile_options["autotune_type"] == OptimizerType.RUNTIME else fw_mem + bw_mem - ) - if cost < best.cost: - best = SplitFwBwBenchmarkUtils(cost=cost, fw_fn=fw_fn, bw_fn=bw_fn, executor=b) + # Do we have a common transformer block optimization enabled? + # If yes we have to restrict the same executor on every bsym + # in the transformer block (e.g. every scaled_dot_product in every transformer block will have the same executor + # as they are expected to work on same input size and shapes). + optmimizer_common_transformer_block = cd.compile_options.get('autotune_optimize_common_blocks', False) + # The generated hash will rely on the operation, input args metadata and output metadata + h = symbol_hash(bsym) + if h in _autototune_common_bsym_in_blocks_cache and optmimizer_common_transformer_block: + best = _autototune_common_bsym_in_blocks_cache[h] + else: + # Restrict the search space + backends = list(requested_executors_list_for_bsym) + + log(f"Search space for {bsym.sym.name}: {backends}", level=LogLevel.INFO) + for b in backends: + log(f"Benchmarking executor {b.name} for {bsym.sym.name}", level=LogLevel.INFO) + # Let downstream fn to pick up this + requested_executors_list_for_bsym.remove(b) + requested_executors_list_for_bsym.insert(0, b) + cd.executors_list = requested_executors_list_for_bsym + fw_fn, bw_fn, fw_trace, bw_trace = _make_aug_forward_and_backward(return_traces=True, update_cache=False) + # What should be the optimal iter? + # TODO: make benchmark info taken from an autotuner config + fw_time, fw_mem, _ = benchmark_trace(fw_trace, iters=100, apply_del_last_used=False) + bw_time, bw_mem, _ = benchmark_trace(bw_trace, iters=100, apply_del_last_used=False, fw_trace=fw_trace) + cost = ( + fw_time + bw_time if cd.compile_options["autotune_type"] == OptimizerType.RUNTIME else fw_mem + bw_mem + ) + if cost < best.cost: + best = SplitFwBwBenchmarkUtils(cost=cost, fw_fn=fw_fn, bw_fn=bw_fn, executor=b) assert best.cost != float("inf") - from thunder.backend_optimizer.optimizer import log log( f"Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}", level=LogLevel.INFO, ) + # Cache the bsym result for common trace's common block reductions + if bsym.sym.name in ['linear', 'scaled_dot_product_attention'] and optmimizer_common_transformer_block: + _autototune_common_bsym_in_blocks_cache[h] = best + # Update the compile options cd.compile_options["autotune_executors_placed_by_fw_bw_split"].add(best.executor) from thunder.executors.transformer_engineex import transformer_engine_ex From 4b05a39f455a4aa58f2da7497264310ae70c6f3e Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 31 Aug 2024 14:16:59 +0300 Subject: [PATCH 131/171] Docs --- thunder/benchmarks/utils.py | 78 +++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index bcaa7007cc..2a541e985f 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -1,12 +1,20 @@ -from collections.abc import Callable, Sequence +from collections.abc import Callable import torch from thunder.backend_optimizer.utils import benchmark_trace -from thunder.core.trace import TraceCtx warm_up_iters = 50 - class SplitFwBwBenchmarkUtils: + """ + Represents a benchmark result container. + It should be used when a single trace region is benchmarked as it can store an optimal executor (referred to the bsym under investigation). + + Attributes: + cost: The benchmark result. Can be compute time or peak memory usage. + fw_fn: Storage for a forward trace. + bw_fn: Storage for a backward trace. + executor: An OperatorExecutor. + """ def __init__( self, *, cost: float = float("inf"), fw_fn: Callable | None = None, bw_fn: Callable | None = None, executor=None ) -> None: @@ -16,29 +24,18 @@ def __init__( self.executor = executor -class AutotunerTorchAutogradBenchmarkUtils: - def __init__( - self, - cost: float = float("inf"), - fw_trace: TraceCtx | None = None, - bw_trace: TraceCtx | None = None, - fw_traces: Sequence[TraceCtx] = [], - bw_traces: Sequence[TraceCtx] = [], - primal_trace: TraceCtx | None = None, - executor=None, - selected_executors: Sequence = [], - ) -> None: - self.cost: float = cost - self.fw_trace = fw_trace - self.bw_trace = bw_trace - self.fw_traces = fw_traces - self.bw_traces = bw_traces - self.primal_trace = primal_trace - self.executor = executor - self.selected_executors = selected_executors +def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int) -> None: + """ + Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). + This util will generate nvsight system profiles. + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + iters: benchmark iterations. + """ -def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int) -> None: for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): @@ -63,6 +60,16 @@ def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iter def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: + """ + Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). + Forward and backward pass will be both recorded. + + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + iters: benchmark iterations. + """ for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): @@ -117,6 +124,16 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: + """ + Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). + The complete time will be recorded with no split between forward pass and backward pass. + + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + iters: benchmark iterations. + """ for m, input, label in zip(models, inputs, labels): # Warm up for _ in range(warm_up_iters): @@ -151,6 +168,19 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) def thunder_fw_bw_benchmark( fw_traces: list, bw_traces: list, fw_labels: list, bw_labels: list, iters: int, nvsight: bool = False ) -> None: + """ + Benchmark a foward and backward trace pair. + The requested inputs are TraceCtx objects. + A nvsight profile can be generate if requested. + + Args: + fw_traces: a list of TraceCtx. + bw_traces: a list of TraceCtx. + fw_labels: a list of labels (names) referring to the forward traces. + bw_labels: a list of labels (names) referring to the backward traces. + iters: benchmark iterations. + nvsight: flag to control nvsight profile generation. + """ assert len(fw_traces) == len(bw_traces) == len(fw_labels) == len(bw_labels) for trc, label in zip(fw_traces, fw_labels): c, m, _ = benchmark_trace( From 701b36c39d71c3cff141e7b24182a37109021b02 Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 31 Aug 2024 16:02:01 +0300 Subject: [PATCH 132/171] Docs and cleaning --- thunder/backend_optimizer/utils.py | 333 ++++++++++++++++++++++------- thunder/tests/test_autotuner.py | 22 +- 2 files changed, 262 insertions(+), 93 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 6f761b51cc..6da45c363c 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -26,6 +26,13 @@ # Maybe we can use id(s) def sequence_hash(s: Sequence) -> str: + """ + Create a fake hash for a sequence of elements. + A fake hash is created because it relies on the elements metadata and not on a specific hash function. + + Args: + s: A sequence to hash. + """ def rec(s) -> str: name = "[" for e in s: @@ -46,6 +53,13 @@ def rec(s) -> str: def can_executor_execute(ex: Executor, bsym: BoundSymbol) -> bool: + """ + Wrap the `can_execute` call of the `Executor`. + + Args: + ex: The executor to test. + bsym: The bound symbol to test. + """ try: return ex.can_execute(bsym) except Exception: @@ -55,6 +69,14 @@ def can_executor_execute(ex: Executor, bsym: BoundSymbol) -> bool: def get_first_available_operator_executor( *, bsym: BoundSymbol, executors: Sequence[Executor], empty_hash: str = "empty" ): + """ + Returns the first available executor which can execute the given bound symbol. + + Args: + bsym: The bound symbol to execute. + executors: A list of possible executors. + empty_hash: A label representing an empty executor if none will be found. + """ for ex in executors: if isinstance(ex, FusionExecutor): continue @@ -64,6 +86,13 @@ def get_first_available_operator_executor( def flatten_sequence(sequence: Sequence) -> list: + """ + Flat a sequence containing sub sequences with a dfs search. + By default None elements will be skipped. + + Args: + sequence: The sequence to flatten. + """ res = [] for e in sequence: if isinstance(e, Sequence): @@ -75,6 +104,13 @@ def flatten_sequence(sequence: Sequence) -> list: def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[Proxy]: + """ + Returns all the intermediate outputs that are not used or returned in the input trace. + This can be usefull if we want to force a specific TensorProxy to be returned in a modfied trace to avoid the dce. + + Args: + in_trace: A generic trace. + """ def is_in_sequence(seq: Sequence[Any], t: Proxy): for e in seq: if hasattr(e, "name") and hasattr(t, "name") and e.name == t.name: @@ -90,8 +126,6 @@ def unpack_output(out) -> Sequence[Proxy]: raise RuntimeError(f"Unpack operation not defined for {type(out)}") ans: list[Proxy] = [] - # Currently this is O(max(len(bsym.output)) * N^2) - # Can we check only bsym after the one in the outer loop in the inner loop (over trace.bound_symbols) ? for a in trace_in.bound_symbols: f = False unpacked_out = unpack_output(a.output) @@ -119,6 +153,18 @@ def assign_executors( compile_data=None, fusion_executor_compile_options_to_activate: Any | None = None, ) -> TraceCtx: + """ + Given a not optimized trace (original computation trace) generate a transformed trace with the requested executors. + + Args: + in_trace: The computation trace. + executors_list: A list of executors, one for each trace region. The size of this list is expected to be equal to the number of bound symbols inside the trace. + always_executors: A list of always executors to pick up symbols not picked up by any specific executor. + empty_str: A label representing an empty executor in the executors_list. + compile_data: A reference to the current compilation data. + fusion_executor_compile_options_to_activate: Any fusion exeuctor compilation options that can be enabled during the trace generation (for example nvFuser). + """ + from thunder.executors.passes import _transform_for_operator_executor_execution def _assign_executors(): @@ -234,7 +280,6 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: if t_name not in executor_mapping: # Symbol added by the visitor continue - # raise AssertionError('Failed to retrive key in mapping') saved_ex = executor_mapping[t_name] if isinstance(saved_ex, OperatorExecutor): cached_subsymbols[t_name] = list(bsym.subsymbols) @@ -270,32 +315,57 @@ def visit(bsym: BoundSymbol, ex: Executor) -> transforms.VISIT_TYPE: return _assign_executors() -def operation_in_trace(*, trace: TraceCtx, op: str) -> bool: - # Some optimizations are not available as symbols +def operation_in_trace(*, trace: TraceCtx, op: str, prefix: bool = False) -> bool: + """ + Test if an operation is being used inside a trace. + + Args: + trace: A computation trace. + op: The operation name to be tested. + prefix: Test only the prefix label. + """ + + # This is to query nv_enable_bookend (https://github.com/Lightning-AI/lightning-thunder/blob/339a782e3d75061a065a3d2e47b5206f23aea7c3/thunder/executors/nvfuserex_impl.py#L807) + # as there won't be any references about this in a trace. always_true = set(["bookend"]) if op in always_true: return True for b in trace.bound_symbols: - if b.sym.name == op: - return True + if prefix: + if b.sym.name.startswith(op): + return True + else: + if b.sym.name == op: + return True return False def is_te_used(trace: TraceCtx) -> bool: + """ + Test if transformer engine is being used inside a trace. + + Args: + trace: A computation trace. + """ from thunder.executors.transformer_engineex import linear_bound_symbol_name_prefix from thunder.executors.transformer_engineex import te_functional_linear_backward_name - for bsym in trace.bound_symbols: - if ( - bsym.sym.name.startswith(linear_bound_symbol_name_prefix) - or bsym.sym.name == te_functional_linear_backward_name - ): - return True + if operation_in_trace(trace=trace, op=te_functional_linear_backward_name) or operation_in_trace( + trace=trace, op=linear_bound_symbol_name_prefix, prefix=True + ): + return True + return False def is_backward_trace(trace: TraceCtx) -> bool: + """ + Test if a trace is a backward trace from its signature. + + Args: + trace: A computation trace. + """ sig = trace.signature_with_no_ctx() return sig.find("backward") >= 0 @@ -307,10 +377,27 @@ def benchmark_trace( apply_del_last_used=True, snapshot=False, snapshot_name="", - nvsight: bool = False, - nvsight_fn_name: str = "", + nsight: bool = False, + nsight_fn_name: str = "", **kwargs, ) -> tuple[float, float, Any]: + """ + Benchmark a generic computation trace compute time and peak memory usage. + nsight profiles can be generated if requested. + + If a backward trace is benchmarked, its paired forward trace is requested (with kwargs) as we don't generate inputs + for the backward call from the static args but with the dynamic arguments returned by the forward trace. + + Args: + trace: A computation trace. + iters: Benchmark iterations. + show_func: Print the executed trace if True. + apply_del_last_used: A flag to control if the trace should be executed after a deletion of not used vars call. + snapshot: A flag controlling if memory usage snapshots should be created (https://pytorch.org/docs/stable/torch_cuda_memory.html). + snapshot_name: A label for the generated snapshot. + nsight: A flag contolling if nvsigh profiles should be generated or not. + nsight_fn_name: A label for the nsight iteration name during benchmark loop. + """ from thunder.executors.passes import del_last_used import inspect @@ -334,7 +421,7 @@ def clone_args_if_needed(args): res.append(arg) return tuple(res) - def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: + def compute_time_cost_nsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: warm_up_iters = 50 torch.cuda.empty_cache() @@ -350,7 +437,7 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f for i in range(iters): new_args = clone_args_if_needed(args) torch.cuda.empty_cache() - torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nvsight_fn_name}, iter{i}") + torch.cuda.nvtx.range_push(f"thunder benchmark fn:{nsight_fn_name}, iter{i}") fn(*new_args) torch.cuda.nvtx.range_pop() torch.cuda.cudart().cudaProfilerStop() @@ -360,7 +447,7 @@ def compute_time_cost_nvsight(fn: Callable, iters: int, *args) -> tuple[float, f import inspect trc = inspect.getsource(fn) - print(f"#Trace execution failed for nvsight (error: {e}):\n{trc}") + print(f"#Trace execution failed for nsight (error: {e}):\n{trc}") raise e def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: @@ -410,7 +497,7 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl raise e def build_static_args(sequence: Sequence, **kwargs) -> list: - return transform_proxy_to_torch(sequence, level=0, **kwargs) + return transform_proxies_to_real(sequence, level=0, **kwargs) def backward_trace_args_preprocess() -> list: if "fw_trace" not in kwargs: @@ -447,8 +534,8 @@ def backward_trace_args_preprocess() -> list: input_args.insert(0, saved_for_bw) else: # Currently single trace region backward trace receives as input the saved_for_bw tensors plus some others. - # They are indexed like [saved_for_bw, others...] - # NOTE: This may change in the future + # They are indexed like [saved_for_bw, others...]. + # NOTE: This may change in the future. """ Example: @torch.no_grad() @@ -468,7 +555,7 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa (t5, t6, t7) = cudnn_sdpa_bwd(t4, query, key, value, None, dropout_p, is_causal, t0, t1, t2, t3, scale=None, cat_grad_qkv=False) return {'query': t5, 'key': t6, 'value': t7, 'attn_mask': None, 'dropout_p': None, 'is_causal': None, 'scale': None} - See how the backward trace need t4 as argument recoveered from the static args + See how the backward trace needs t4 as argument recoveered from the static args """ updated_input_args = [t for t in saved_for_bw_C0] updated_input_args.extend( @@ -516,29 +603,13 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa m = float("inf") answer = None try: - if nvsight: - t, m, answer = compute_time_cost_nvsight(executable, iters, *input_args) + if nsight: + t, m, answer = compute_time_cost_nsight(executable, iters, *input_args) else: t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) - except Exception as e: + except Exception: import traceback - - ex_str = traceback.format_exc() - print(ex_str) - # https://github.com/Lightning-AI/lightning-thunder/issues/664 - # Seems that this patch never work ... - if ( - "call_method UserDefinedObjectVariable(set) __contains__ [UserDefinedObjectVariable()] {}" in str(e) - and not nvsight - ): - print( - "Executing with torch compile no full graph (this might still fail), see: https://github.com/Lightning-AI/lightning-thunder/issues/664" - ) - torch_compiled = torch.compile(executable, fullgraph=False) - try: - t, m, answer = compute_time_cost_ms(torch_compiled, executable_str, iters, *input_args) - except Exception as e: - print(f"Compiled trace execution still failed:\n{e}") + traceback.print_exc() finally: reset_tracectx(trace_tok) @@ -549,14 +620,14 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa return t, m, answer -def register_impl_executor(ex: Executor, id: PrimIDs, fn: Callable, checker: Callable) -> None: +def _register_impl_executor(ex: Executor, id: PrimIDs, fn: Callable, checker: Callable) -> None: if ex.name == "nvfuser": from thunder.executors.nvfuserex_impl import register_supported register_supported(id, fn, checker) -def recover_ex_from_compile_option(option: str) -> Executor: +def _recover_ex_from_compile_option(option: str) -> Executor: if option.startswith("nv"): from thunder.executors.nvfuserex_impl import ex @@ -566,6 +637,16 @@ def recover_ex_from_compile_option(option: str) -> Executor: def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *args): + """ + Wraps a function call enabling a compile option for a specific executor. + The compile option will be restored after the function completes. + This can be usefull if we want to benchmark a specific compile option. + + Args: + option: The option to be enabled. + fn: A callable function. + args: Function arguments. + """ from thunder.core import compile_data cd = compile_data.get_compile_data() @@ -573,13 +654,12 @@ def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *ar # Update compile option context if cd is None: raise AssertionError("compile_data is None") - # TODO: use getattr old_opt: bool | None = cd.compile_options.get(option.fusion_tag, None) new_opt = True if old_opt is None or old_opt is False else False cd.compile_options[option.fusion_tag] = new_opt # Register the impl for the executor in order to be able to execute the id - register_impl_executor( - recover_ex_from_compile_option(option.fusion_tag), + _register_impl_executor( + _recover_ex_from_compile_option(option.fusion_tag), option.id, option.impl, option.checker, @@ -597,11 +677,27 @@ def wrap_fn_with_exeuctor_compile_option(option, fn: Callable | None = None, *ar def print_trace_args(trace: TraceCtx): + """ + Utility to display a trace arguments. + + Args: + trace: A computation trace. + """ print_nested_sequence(trace.args) -# Display nest sequence arguments def print_nested_sequence(args, show_dicts=False): + """ + Utility to display a sequence of elements with possible nested sequences. + Elements will be retrieved in a dfs manner. + + Args: + args: The input sequence. + show_dicts: Control if dict types should be printed. + """ + + import pprint + def is_tensor(t): return isinstance(t, torch.Tensor) or isinstance(t, TensorProxy) @@ -620,7 +716,7 @@ def _print(args, level): dtype = arg.dtype if is_tensor(arg) else None name = arg.name if isinstance(arg, TensorProxy) else "" print( - f'{tabs}{name + ": " if name else ""}{type(arg)}{arg if isinstance(arg, dict) and show_dicts else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}' + f'{tabs}{name + ": " if name else ""}{type(arg)}{pprint.pformat(arg) if isinstance(arg, dict) and show_dicts else ""} {tensor_shape if tensor_shape else ""} {dtype if dtype else ""}' ) print(f"Level {level} end") @@ -629,6 +725,11 @@ def _print(args, level): def update_compile_options_executor_list_after_fw_bw_split() -> None: + """ + Updates the compile options with the executors that have been placed by the forward-backward split pass. + This utility can be used to save all the executors that have been effectively placed in a trace. + """ + from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options cd = get_compile_data() @@ -653,10 +754,16 @@ def update_compile_options_executor_list_after_fw_bw_split() -> None: def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: + """ + Retrive the associated torch.Tensor from a proxy tensor by reading its metadata. + This will allocate the real tensor in memory. + This utility can read transformer engine compilation requests and generate the associated FP8 tensor if needed. + + Args: + arg: The proxy tensor. + """ from thunder.core.dtypes import is_float_dtype, is_signedinteger_dtype, is_boolean_dtype - # TODO (matteochen): Missing parallel and fsdp handling... - # TODO (matteochen): Missing support for meta types ... dtype = arg.dtype shape = arg.shape device = arg.device @@ -698,13 +805,21 @@ def transform_tensor(arg: TensorProxy, **kwargs) -> torch.Tensor: return tensor -def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | list: +def transform_proxies_to_real(sequence: Sequence, level=0, **kwargs) -> tuple | list: + """ + Retrieve a sequence of real arguments relative to a sequence of proxy arguments. + This supports also nested sequences in a recursive way. + + Args: + sequence: The input proxy sequence. + level: An utility integer representing the search dept. + """ from thunder.executors.transformer_engineex import Context as C res = [] for e in sequence: if type(e) is tuple: - res.append(transform_proxy_to_torch(e, level + 1, **kwargs)) + res.append(transform_proxies_to_real(e, level + 1, **kwargs)) else: if isinstance(e, TensorProxy): res.append(transform_tensor(e, **kwargs)) @@ -742,6 +857,16 @@ def transform_proxy_to_torch(sequence: Sequence, level=0, **kwargs) -> tuple | l def reorder_executors_list(executors: Sequence, **kwargs): + """ + Reorders a random executors list to be compatible with the autotuner compilation flow. + This will put in the front of the returned list all the executors with a grad fn. + All the other executors will be appended afterwards. + + If no fusion executors is present inside the input list, a default one will be added in order to trigger the autotuning process. + + Args: + executors: The executors to be reordered. + """ from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options from thunder.executors.torch_compile import torch_compile_ex from thunder.executors.nvfuserex_impl import ex as nvfuser_ex @@ -780,7 +905,14 @@ def reorder_executors_list(executors: Sequence, **kwargs): def symbol_hash(bsym: BoundSymbol): - # Maintainig essential metadata + """ + Hash a bound symbol relying on its metadata (symbol name, bound symbol inputsa and outputs). + No hash functions will be applied in order to leave the output readable. + + Args: + bsym: A bound symbol. + """ + def _tensor_hash(t: TensorProxy) -> str: assert t.dtype shapes = [str(s) for s in t.shape] @@ -837,10 +969,26 @@ def _hash(bsym: BoundSymbol) -> str: # TODO: known_points can be used to detect start and end of a block sequence def repetead_trace_blocks( *, trace: TraceCtx, min_block_size=2, known_points: tuple[BoundSymbol, BoundSymbol] | None = None -) -> list[tuple]: +) -> list[tuple[int, int]]: + """ + Detects if are there repeated sections inside a given trace. + This utility can be employed on traces referring to transformer based models where the layers are repeated N times. + + The return list will contain a tuple of two elements pointing to the index (in the computation trace) of where a block starts and ends (both included). + + The variable min_block_size can be tuned in order to not allucinate this function by capturing unwanted sections (small sections) if no repeated transformer layers can be found. + + Args: + trace: A computation trace. + min_block_size: The minimum block lenght, by default 2. + known_points: If a practitioner already knows where a transformer layer starts and ends inside a given trace, these points can be supplied in order to speed up the search. Currently not implemented. + """ if min_block_size < 2: return [] + if known_points is not None: + raise RuntimeError('known_points research is not supported.') + symbols = [ s for s in trace.bound_symbols @@ -918,7 +1066,6 @@ def _range_seen(index: int, s: set): continue h = symbol_hash(bsym) - # Could not generate hash for this bsym # Normally, bsym are expected to output a TensorProxy if not isinstance(bsym.output, Proxy) or h in seen_hashes or _range_seen(i, seen_ranges): continue @@ -959,14 +1106,25 @@ def _range_seen(index: int, s: set): ] -# What is regions_between_blocks? -# They are trace regions between one transformer block and the next one in the backward pass and given that these regions are not present -# at the end of the last transformer block it means that they are needed in order to prepare shapes or strides -# for the block at i+1 from the output of block i. -# For example if common blocks looks like: [(32, 155), (157, 280)] -# the symbol at index 156 (the gap) looks like: -# In the forward trace we have not these gaps (so far). def _regions_between_blocks(trace: TraceCtx, common_blocks: list[tuple]) -> int: + """ + Retrieve the size of a gap region between common blocks. + + What is regions_between_blocks? + They are trace regions between one transformer block and the next one (usually found in the backward trace) and given that these regions are not present + at the end of the last transformer block it means that they are needed in order to prepare shapes or strides + for the block at i+1 from the output of block i. + For example if common blocks looks like: [(32, 155), (157, 280)] + the symbol at index 156 (the gap) could generally be: (for torch.float32 dtype, if another dtype is used the trace may contain other ops in this region leading to a larger gap). + In the forward trace we have not these gaps (so far). + + In the example above the returned value will be 1. + + Args: + trace: A computation trace. + common_blocks: A list containig the common blocks for the given trace. + + """ def _assert_args(seq_a: Sequence, seq_b: Sequence): assert len(seq_a) == len(seq_b) for a, b in zip(seq_a, seq_b): @@ -996,6 +1154,12 @@ def _assert_args(seq_a: Sequence, seq_b: Sequence): def _indices_to_exclude_between_common_blocks(common_blocks: list[tuple]) -> list: + """ + Retrive the indicies referring to the gaps between one common block and the next one. + + Args: + common_blocks: A computed common block list for a given trace. + """ if len(common_blocks) < 2: return [] @@ -1010,6 +1174,16 @@ def _indices_to_exclude_between_common_blocks(common_blocks: list[tuple]) -> lis def reduce_common_trace_blocks( *, trace: TraceCtx, common_blocks_in: list[tuple], skip_between_blocks: bool = True ) -> TraceCtx: + """ + Generate a reduced trace (shorter computation nodes) given a common block pattern. + + This can be useful to speed up the executor tuning for models with repeated layers. + + Args: + trace: A computation trace. + common_blocks_in: A previously computed common block pattern. + skip_between_blocks: A flag to control if gaps between common blocks should be included in the output trace or not. See _regions_between_blocks. + """ def _exclude(blocks: list[tuple[int, int]], index: int, black_list: set): # Exclude if the index is in a repeated block for block in blocks: @@ -1057,20 +1231,8 @@ def _find_bsym_index(out_name: str, space: Sequence[BoundSymbol]) -> int: b for i, b in enumerate(trace.bound_symbols) if not _exclude(common_blocks[1:], i, index_gaps_to_exclude) ] - # Retrive first and last blocks - first_block = common_blocks[0] - # common_blocks = common_blocks[1:] - - # Now, we have to update the trace region inputs after the last block to accepts the outputs of the first block, if it's not the return statement + # Now, we have to update the trace region inputs after the last block to accepts the outputs of the first block, if it's not the return statement. if trace.bound_symbols[common_blocks[-1][1] + 1].sym.id != PrimIDs.RETURN: - # first_block_outputs = trace.bound_symbols[first_block[1]].output - # last_block_outputs = trace.bound_symbols[common_blocks[-1][1]].output - - # if not isinstance(first_block_outputs, Sequence): - # first_block_outputs = [first_block_outputs] - # if not isinstance(last_block_outputs, Sequence): - # last_block_outputs = [last_block_outputs] - symbol_to_correct_index = _find_bsym_index( trace.bound_symbols[common_blocks[-1][1] + 1].output.name, bound_symbols ) @@ -1104,8 +1266,6 @@ def _correct_bsym(bsym: BoundSymbol) -> BoundSymbol: args=_correct_args(symbol_to_correct), subsymbols=new_subsymbols ) - # print(bound_symbols[symbol_to_correct_index]) - # We need to check also the return statements as we have fewer args now flatten_bsyms = flatten_sequence([b.output for b in bound_symbols]) args_remained = set([b.name for b in flatten_bsyms if b is not None and hasattr(b, "name")]) @@ -1130,7 +1290,6 @@ def _correct_bsym(bsym: BoundSymbol) -> BoundSymbol: bound_symbols[-1] = bsym # Bw trace else: - def _returned(seq: Sequence) -> tuple: ret = [] for e in seq: @@ -1155,12 +1314,22 @@ def _returned(seq: Sequence) -> tuple: return extrace -# NOTE: This implementation currently relies on the fact that transformer blocks are contiguous in trace -# or they have a common gap region between them (in case for bw trace). -# TODO: generalize this def map_executors_from_reduced_trace_to_complete_trace( complete_trace: TraceCtx, common_blocks: list[tuple], ex_mappings: list[Executor] ) -> list[Executor]: + """ + Generate executors mappings (trace region -> executor) for the complete trace once the optimization has been performed on a reduced trace. + + This implementation currently relies on the fact that transformer blocks are contiguous in trace + or they have a common gap region between them (in case for bw trace). + + The output executor list has size equal to the complete trace regions size. + + Args: + complete_trace: A computation trace. + common_blocks: A previously computed common block pattern. + ex_mappings: The executor mappings for the reduce trace. + """ from thunder.executors.torchex import ex as torch_ex if len(common_blocks) < 2: diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 041da26920..baf24e7a02 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -336,12 +336,12 @@ def test_update_compile_options_executor_list_after_fw_bw_split(model, q, k, v, assert count == expected -def _test_transform_proxy_to_torch_fn_1(a: torch.Tensor, b: torch.Tensor, k: int): +def _test_transform_proxies_to_real_fn_1(a: torch.Tensor, b: torch.Tensor, k: int): t0 = a * b return t0 * k -def _test_transform_proxy_to_torch_fn_2( +def _test_transform_proxies_to_real_fn_2( a: torch.Tensor, b: torch.Tensor, c: tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]] ): t0 = c[0] + c[1][0] @@ -349,7 +349,7 @@ def _test_transform_proxy_to_torch_fn_2( return t1 - a + b -def _test_transform_proxy_to_torch_common( +def _test_transform_proxies_to_real_common( fn: Callable, torch_args: tuple, executors: list, has_backward: bool, **kwargs ): jitted = thunder.jit(fn, executors=executors) @@ -358,7 +358,7 @@ def _test_transform_proxy_to_torch_common( trace_static_args = thunder.last_traces(jitted)[-1].args assert trace_static_args - transformed_args = aut_utils.transform_proxy_to_torch(trace_static_args, **kwargs) + transformed_args = aut_utils.transform_proxies_to_real(trace_static_args, **kwargs) assert isinstance(transformed_args, list) @@ -389,7 +389,7 @@ def _comp(thunder_seq: Sequence, torch_seq: Sequence): trace_static_args = thunder.last_backward_traces(jitted)[-1].args assert trace_static_args - transformed_args = aut_utils.transform_proxy_to_torch(trace_static_args, **kwargs) + transformed_args = aut_utils.transform_proxies_to_real(trace_static_args, **kwargs) print(trace_static_args) # print(transformed_args) @@ -399,9 +399,9 @@ def _comp(thunder_seq: Sequence, torch_seq: Sequence): @pytest.mark.parametrize( "fn, torch_args, executors, has_backward", [ - (_test_transform_proxy_to_torch_fn_1, tuple([torch.randn(1, 1), torch.randn(1, 1), 10]), [], False), + (_test_transform_proxies_to_real_fn_1, tuple([torch.randn(1, 1), torch.randn(1, 1), 10]), [], False), ( - _test_transform_proxy_to_torch_fn_2, + _test_transform_proxies_to_real_fn_2, tuple([torch.randn(1, 1), torch.randn(1, 1), (torch.randn(1, 1), (torch.randn(1, 1), torch.rand(1, 1)))]), [], False, @@ -418,12 +418,12 @@ def _comp(thunder_seq: Sequence, torch_seq: Sequence): ), ], ) -def test_transform_proxy_to_torch(fn: Callable, torch_args: tuple, executors: list, has_backward: bool): - _test_transform_proxy_to_torch_common(fn, torch_args, executors, has_backward) +def test_transform_proxies_to_real(fn: Callable, torch_args: tuple, executors: list, has_backward: bool): + _test_transform_proxies_to_real_common(fn, torch_args, executors, has_backward) @requiresCUDA -def test_transform_proxy_to_torch_TE(): +def test_transform_proxies_to_real_TE(): class Model(torch.nn.Module): def __init__(self, in_features, out_features) -> None: super().__init__() @@ -435,7 +435,7 @@ def forward(self, x: torch.Tensor): model = Model(4096, 4096) model.to("cuda") - _test_transform_proxy_to_torch_common( + _test_transform_proxies_to_real_common( model, tuple([torch.randn(4096, 4096, requires_grad=True, device="cuda")]), ["transformer_engine"], From cce3ea3d7c454f338ac570b4e43192bfb390f88c Mon Sep 17 00:00:00 2001 From: matteochen Date: Sat, 31 Aug 2024 17:32:34 +0300 Subject: [PATCH 133/171] Docs and reorganization --- thunder/backend_optimizer/optimizer.py | 332 +++++++++++++++++-------- thunder/backend_optimizer/utils.py | 101 +++++++- thunder/core/vjp_utils.py | 4 +- thunder/tests/test_autotuner.py | 2 +- 4 files changed, 335 insertions(+), 104 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 6a3f66cb4e..98a11add02 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -13,61 +13,27 @@ from thunder.core.transforms import construct_trace from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors from typing import Hashable -from thunder.backend_optimizer.utils import benchmark_trace - - -# This fn is used before compile data being set, rely on kwargs -def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None, **kwargs) -> list | dict: - from thunder.executors.sdpaex import sdpa_ex - from thunder.executors.cudnnex import cudnn_ex - from thunder.executors.fa3ex import fa3_ex - from thunder.executors.transformer_engineex import transformer_engine_ex - - if kwargs is None or not kwargs.get("autotune_enable_te", False): - options: dict[str, list] = { - "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], - } - else: - options: dict[str, list] = { - "linear": [transformer_engine_ex], - "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], - } - - return options.get(bsym.sym.name, []) if bsym else options - - -class BenchmarkResult: - def __init__( - self, - *, - time: float = float("inf"), - memory: float = float("inf"), - trace: TraceCtx = TraceCtx(), - label: str | Hashable = "", - index: int = -1, - ) -> None: - self.runtime: float = time - self.memory: float = memory - self.trace: TraceCtx = trace - self.label: str | Hashable = label - self.index: int = index - - -class OptimizerType(Enum): - MEMORY = 0 - RUNTIME = 1 - - -class TraceType(Enum): - FW = 0 - BW = 1 +from thunder.backend_optimizer.utils import benchmark_trace, BenchmarkResult, OptimizerType, TraceType, LogLevel, log class OptimizationAlgorithm(Enum): + """ + Represents the optimization technique used by the autotuner. + """ BEST_FUSER = 0 class FusionCompileOptionsHelper: + """ + Represents compile options for a fusion executor. + + Attributes: + fusion_tag: A label representing the fusion ops regarding a compile option (e.g. nv_linear). + symbol_tag: The symbol name + id: The symbol id. + impl: A callable implementation. + checker: A callable checker. + """ def __init__(self, fusion_tag: str, symbol_tag: str, id: PrimIDs, impl: Callable, checker: Callable) -> None: self.fusion_tag = fusion_tag self.symbol_tag = symbol_tag @@ -77,6 +43,18 @@ def __init__(self, fusion_tag: str, symbol_tag: str, id: PrimIDs, impl: Callable class TraceCandidate: + """ + Represents an optimal trace candidate. + + Attributes: + trace: The candidate trace. + compile_opt: Any compile options used for the current candidate. + label: A generic label. + symbol_tag: The symbol name + id: The symbol id. + impl: A callable implementation. + checker: A callable checker. + """ def __init__(self, *, trace: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, label: str) -> None: self.trace: TraceCtx = trace self.compile_opt: FusionCompileOptionsHelper | None = compile_opt @@ -84,6 +62,15 @@ def __init__(self, *, trace: TraceCtx, compile_opt: FusionCompileOptionsHelper | class TraceCandidates: + """ + Represents an optimal pair of trace candidates (compute time and memory consumption). + + Attributes: + best_time: The trace with the optimal runtime. + best_mem: The trace with the optimal peak memory consumption. + compile_opt_time: Any compile options used for a fusion executor regarding the first trace. + compile_opt_mem: Any compile options used for a fusion executor regarding the second trace. + """ def __init__( self, best_time: TraceCtx | None = None, @@ -97,25 +84,52 @@ def __init__( self.compile_opt_mem: FusionCompileOptionsHelper | None = compile_opt_mem def __repr__(self) -> str: + """ + Give a representation for the current object. + """ return f"\nBest runtime candidate:\n{self.best_time}\nBest memory candidate:\n{self.best_mem}" def is_set(self) -> bool: + """ + Check that the optimal trace pair has been set. + """ return False if self.best_time is None or self.best_mem is None else True def attach_best_time_candidate(self, trace: TraceCtx): + """ + Attach a new best time trace result. + """ self.best_time = trace def attach_best_mem_candidate(self, trace: TraceCtx): + """ + Attach a new best memory trace result. + """ self.best_mem = trace def iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: + """ + Returns an iterable object over the traces contained in the current object. + """ return self.best_time, self.best_mem def compile_opt_iterables(self) -> tuple[FusionCompileOptionsHelper | None, FusionCompileOptionsHelper | None]: + """ + Returns an iterable object over the compile options used in the traces contained in the current object. + """ return self.compile_opt_time, self.compile_opt_mem class OutputCandidate: + """ + Represents a final output candidate: forward and backward trace pair. + + Attributes: + fw: The forward trace. + bw: The backward trace. + compile_opt: Any compile options being used for a fusion executor. + tot_cost: The total cost to execute the pair (ms for a time strategy and GB for a memory strategy). + """ def __init__( self, *, fw: TraceCtx, bw: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, cost: float = 0.0 ) -> None: @@ -125,14 +139,23 @@ def __init__( self.tot_cost: float = cost def __repr__(self) -> str: + """ + Give a representation of the current object. + """ return f"Final output candidate: forward trace:\n{self.fw.__repr__()}\nFinal output candidate: backward trace:\n{self.bw.__repr__()}" -# Benchmark only traces will contain traces after the rematerialization call with fw and bw calls, reproducing what will be the real traces after the autotune pass -# Non benchmark traces will contain traces after the placement (default) with no call to remat -# We have duplciated those in order to maintain thunder compilation flow as the output from the autotuner will be the traces with no pass through rematerialization -# TODO: torchcompile_cat currently is not supported as the autotuner search space in the FusionExecutor section is limited to 1 class FusionStratHelper: + """ + Represents a helper structure for the fusion strategy. + + Attributes: + supported_executors: A list of supported fusion executors. + optimized_traces_mem: a list of dictionaries containing informations regarding the optimized traces for peak memory consumption. + optimized_traces_mem_benchmark_only: a list of dictionaries containing informations regarding the optimized traces for peak memory consumption (used only for internal benchmarking). + optimized_traces_time: a list of dictionaries containing informations regarding the optimized traces for total compute time. + optimized_traces_time_benchmark_only: a list of dictionaries containing informations regarding the optimized traces for total compute time (used only for internal benchmarking). + """ def __init__(self) -> None: self.supported_executors: set = set(["nvfuser", "torchcompile"]) self.optimized_traces_mem: list[dict[str | Hashable, tuple[TraceCtx, FusionCompileOptionsHelper | None]]] = [] @@ -142,31 +165,54 @@ def __init__(self) -> None: class FusionExecutorsPlacementCtx: + """ + Represents a executor placement context. + + Attributes: + placement: A list of executors. + compile_options: Any compile options being used for the fusion executor contained in the placement. + """ def __init__(self, *, placement: list, compile_options: FusionCompileOptionsHelper | None = None) -> None: self.placement: list = placement self.compile_options: FusionCompileOptionsHelper | None = compile_options class ExecutorPlacementOptions: + """ + Represents an aggregate placement options for executors combining those that targets peak memory consumption and those for total compute time. + + Attributes: + placement_options_mem: A list of placement contexts. + placement_options_time: A list of placement contexts. + """ def __init__(self) -> None: self.placement_options_mem: list[FusionExecutorsPlacementCtx] = [] self.placement_options_time: list[FusionExecutorsPlacementCtx] = [] -class LogLevel(Enum): - DEBUG = 0 - INFO = 1 - - -log_level: LogLevel = LogLevel.INFO - - -def log(what: str, level: LogLevel = LogLevel.INFO): - if log_level == LogLevel.DEBUG or log_level == level: - print(f"================================================================================ Autotune: {what}") - - class PlacerBase: + """ + Represents a base (interface) class for a placement class. + + Attributes: + always_executors: A list of always present executors. + empty_executor_hashable_placeholder: A label representing en empty executor. + executors: A list of executors to use. + fusion_executors: A list of fusion executors to use. + fusion_executors_saved_for_later: A helper list containing maybe repeated fusion executors. + debug_msg: A dynamic filled log message. + log_file_name: The output log file name if generated. + produce_log: A tuning parameter to control log file generation. + optimizer_type: The optimization target. + active_fw_trace_ctx: An active set forward trace (in the object scope). + cached_fw_traces: Cached forward traces. + bw_trace_candidates: An instance of trace candidates. + best_pair_runtime: A final traace pair targetting the compute time. + best_pair_memory: A final traace pair targetting the peak memory consumption. + apply_bucketing_bw_trace: A distributed flag. + benchmark_iters: Benchmark iteration steps. + compile_data: Thunder compilation data. + """ def __init__( self, *, @@ -187,7 +233,6 @@ def __init__( self.fusion_executors_saved_for_later: Sequence[FusionExecutor] = [] self.debug_msg: str = "" - self.partial_costs: dict[TraceCtx, float] = {} self.log_file_name: str = log_file_name self.produce_log: bool = produce_log @@ -207,19 +252,44 @@ def __init__( self.compile_data = compile_data def optimize(self): + """ + Optimize the executor placement for the current trace. + """ pass def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + """ + Attach a new trace for executors optimization. + + Args: + trace: The trace to attach. + trace_type: Forward or backward trace refrence. + """ pass def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + """ + Retrive the optimal forward traces that the object has tuned. + """ return [] def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + """ + Retrive the optimal forward and backward trace pair. + """ return (TraceCtx(), TraceCtx()) class FusionPlacer_BeamSearch(PlacerBase): + """ + Represents a placer targetting the fusion regions. + + Attributes: + fusion_strat_helper: A helper structures to save intermediate values. + executor_placement_options: A helper structures to save different intemediate executor placement. + is_reduced: A flag indicating if the current trace under optimization is a reduced version of a bigger trace (by common blocks reduction). + cached_original_trace: A reference to the original trace if the optmization is performed on a reduced version. + """ def __init__( self, *, @@ -261,7 +331,8 @@ def __init__( } # Transformer based models optimization - # TODO: explain + # For models based on layers of transformer blocks we can optimize the tuning by researching the best placement + # on the model with a single layer and then mirror the configuration to the other layers. self.is_reduced: bool = False self.cached_original_trace: TraceCtx | None = None @@ -270,6 +341,12 @@ def __init__( """ def _best_runtime_and_memory_candidates(self, candidates): + """ + Retrive the best compute time and peak memory consumption trace pairs. + + Args: + candidates: A sequence of possible candidates. + """ from thunder.core.rematerialization import rematerialize_forward_and_backward from thunder.backend_optimizer.utils import benchmark_trace @@ -328,12 +405,17 @@ def _best_runtime_and_memory_candidates(self, candidates): return best_pair_runtime, best_pair_memory def _filter_candidates(self): + """ + Reduce the solutions count by comparing different options across different fusion executors. + + For forward traces all the options are cached. + """ self.debug_msg += "Traces benchmarks:\n\n" # We cache every optimized fw traces as they might impact differently on the bw trace # Number of fw traces to cached are: #fusion_executors * 2 def fw_benchmark(): - # The optimizator builds the results in order following the self.fusion_executors list order + # The optimizer builds the results in order following the self.fusion_executors list order pair_time: dict pair_mem: dict for pair_time, pair_mem in zip( @@ -392,15 +474,14 @@ def bw_benchmark(): ) # Here we have to recover the traces without the pass through remat in order to be compliant - # with thunder flow as we might have request for no remat - # Unpack dict + # with thunder flow as we might have request for no remat. trc = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0][0] self.bw_trace_candidates.attach_best_time_candidate(trc) trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0][0] self.bw_trace_candidates.attach_best_mem_candidate(trc) # Now, finally build the pair fw and bw traces - # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller + # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller. for bw in self.bw_trace_candidates.iterable(): self.out_traces_candidates.append( OutputCandidate(fw=self.active_fw_trace_ctx[0], bw=bw, compile_opt=self.active_fw_trace_ctx[1]) @@ -423,6 +504,15 @@ def bw_benchmark(): self.debug_msg = "" def _search_candidates(self, increment_factor: int = 1): + """ + For the current trace generate all the placement candidates. + + For each fusion executor the time-memory pair candidates will be generated and cached. + If any compile options for an executor is available, it will be take under consideration. + + Args: + increment_factor: An integer controlling the increment step during the fusion exclusion to speed up the compilation. + """ from thunder.executors.data_dependent_partition import Node, fuse_bound_symbols from thunder.backend_optimizer.utils import ( get_not_used_intermediate_outsputs, @@ -433,6 +523,13 @@ def _search_candidates(self, increment_factor: int = 1): ) def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[BoundSymbol]): + """ + Generates a trace with the requested executors. + + Args: + mapping: a dictionary pointing to the assigned executor for a trace region. + bound_symbols_in: Input trace regions. + """ trc = from_trace(self.trace) trc.bound_symbols = list(bound_symbols_in) @@ -484,11 +581,20 @@ def get_placed_trace(mapping: dict[str, Executor], bound_symbols_in: Sequence[Bo def _search(ex: FusionExecutor, executor_compile_option: FusionCompileOptionsHelper | None = None): """ - Fusable fn definition for nvFuser - """ + For the given executor search and cached the best placements. - # Each executor has a custom should fuse function, but the current impl need to access local executor object + Args: + ex: A fusion executor. + executor_placement_options: Any compile option this executor might activate. + """ def _should_fuse_nvfuser(a: Node, b: Node): + """ + Fusable fn definition for nvFuser. + + Args: + a: First node. + b: Second node. + """ def _can_fuse_node(n: Node): # if already merged, then node can be fused if len(n.group_bsyms) > 1: @@ -500,11 +606,15 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - """ - Fusable fn definition for torch.compile - """ def _should_fuse_torchcompile(a: Node, b: Node): + """ + Fusable fn definition for torch.compile. + + Args: + a: First node. + b: Second node. + """ def _can_fuse_node(n: Node): if len(n.group_bsyms) > 1: return True @@ -513,7 +623,15 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): + def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): + """ + Match a bound symbol to its executor. + + Args: + bsym_in: The bound symbol to match. + dicts: The matrching destination. + ex_in: The executor to assign. + """ if isinstance(bsym_in.output, Sequence): for d in dicts: d[sequence_hash(bsym_in.output)] = ex_in @@ -548,6 +666,7 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): dict_time_strat: dict[str, Executor] = {} dict_mem_strat: dict[str, Executor] = {} increasing_symbols = [] + # Tuning starting point: iterate over all the groups. for group_id, group in enumerate(bound_symbol_groups): log(f"Fusion group id: {group_id}", level=LogLevel.DEBUG) log( @@ -584,7 +703,7 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): # Not executors available if not candidate_executors: - match_bsym_output( + match_bsym_executor( current_bsym, [dict_time_strat, dict_mem_strat], Executor(name=self.empty_executor_hashable_placeholder), @@ -605,9 +724,7 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): candidate_best_time = BenchmarkResult(index=0) candidate_best_mem = BenchmarkResult(index=0) else: - # Search for best candidate, by default remat will be called to find the optimal choice - # TODO: enable requests for no remat becnhmarks - # TODO: we should consider also FusionExecutor that can execute this single bsym in this beam search + # Search for best candidate for i, candidate in enumerate(candidate_executors): from thunder.common import transform_for_execution @@ -640,8 +757,8 @@ def match_bsym_output(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor): level=LogLevel.DEBUG, ) - match_bsym_output(current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) - match_bsym_output(current_bsym, [dict_mem_strat], candidate_executors[candidate_best_mem.index]) + match_bsym_executor(current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) + match_bsym_executor(current_bsym, [dict_mem_strat], candidate_executors[candidate_best_mem.index]) # Go to next bsym group continue @@ -686,7 +803,7 @@ def measure_and_update_result(): if last_embedding_idx != -1: # Until last_embedding_idx (included) assigned to current fusion ex for i in range(0, last_embedding_idx + 1, 1): - match_bsym_output(group[i], [dict_time_strat, dict_mem_strat], ex) + match_bsym_executor(group[i], [dict_time_strat, dict_mem_strat], ex) if last_embedding_idx == len(group) - 1: # Benchmark @@ -695,7 +812,7 @@ def measure_and_update_result(): start_idx = last_embedding_idx + 1 n_missing_bsyms = len(group) - start_idx - # TODO (matteochen): consider to add the iteration with no fusion regions + # Tune a single fusion group. for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): if ex.name == "torchcompile": import torch @@ -707,9 +824,9 @@ def measure_and_update_result(): # -> First iteration is the one with fusion region with single element # -> Last iteration gives the complete fusion region for j in range(start_idx, start_idx + i + 1, increment_factor): - match_bsym_output(group[j], [dict_time_strat, dict_mem_strat], ex) + match_bsym_executor(group[j], [dict_time_strat, dict_mem_strat], ex) for k in range(start_idx + i + 1, len(group), increment_factor): - match_bsym_output( + match_bsym_executor( group[k], [dict_time_strat, dict_mem_strat], # In order to benchmark the fusion placecement, we can use any executor for the excluded bsym from the fusion region @@ -776,7 +893,7 @@ def measure_and_update_result(): raise AssertionError(f"Type not handled: {type(bsym.output)}") # For the forward trace we benchmark (memory) the mocked return statement as we don't know which - # tensor will be returned after the rematerialize_forward_and_backward call in order to do not underestimate the memory consumption + # tensor will be returned after the rematerialize_forward_and_backward call in order to do not under/over-estimate the memory consumption trace = self.trace if self.trace_type == TraceType.FW: trace = from_trace(self.trace) @@ -801,7 +918,6 @@ def measure_and_update_result(): ) container.append({ex.name: trc}) - # Save executors in order to generate real fw and bw trace with correct output with the placer # We add any provided compile option reference self.executor_placement_options.placement_options_time.append( FusionExecutorsPlacementCtx(placement=executors_time, compile_options=executor_compile_option) @@ -810,8 +926,7 @@ def measure_and_update_result(): FusionExecutorsPlacementCtx(placement=executors_mem, compile_options=executor_compile_option) ) - # If executor specific compile option is activated we need to know where a specific - # trace does come from and the zip logic afterward can not be employed with self.fusion_executors list + # If any compile options will be used we will need to have duplicated executors inside the executors list to maintain the matching. self.fusion_executors_saved_for_later = [] ex: FusionExecutor for ex in self.fusion_executors: @@ -919,7 +1034,7 @@ def _optimize(): level=LogLevel.INFO, ) - # This performs executor search + # This performs executor tuning self._search_candidates() # From now on we have the optimized executors for each trace region. Apply them... @@ -950,7 +1065,6 @@ def _optimize(): # Reset original trace self.trace = self.cached_original_trace - # We will create the best compute time and peak memory consumption placement for each fusion executor for placement_ctx, ex in zip( self.executor_placement_options.placement_options_time, self.fusion_executors_saved_for_later @@ -989,13 +1103,13 @@ def _optimize(): # Clear any previous results self.cached_fw_traces = [] _optimize() - # We have multiple cached optimized fw traces, find the best backward - # TODO: make this prettier with a machine state for example + # We have multiple cached optimized fw traces, this iteration will create a fw-bw pair for + # every cached forward trace. At the end the best one will be picked up. case TraceType.BW: # Clear any previous results self.out_traces_candidates = [] - # Cached the bw trace as we need to modify the input trace during the loop + # Cached the bw trace as we need to modify the self.trace during the loop cached_self_trace = from_trace(self.trace) cached_self_trace.bound_symbols = list(self.trace.bound_symbols) @@ -1016,6 +1130,7 @@ def _optimize(): self.trace = update_bw_from_forward_optimization(fw=fw_trace_candidate.trace, bw=self.trace) + # Taken from: https://github.com/Lightning-AI/lightning-thunder/blob/339a782e3d75061a065a3d2e47b5206f23aea7c3/thunder/executors/torch_autograd.py#L222 if self.apply_bucketing_bw_trace: from thunder.distributed.transforms import FSDPCommBucketing @@ -1030,12 +1145,19 @@ def _optimize(): else: _optimize() + # For every pair being generated filter out the best choice. self.best_pair_runtime, self.best_pair_memory = self._best_runtime_and_memory_candidates( self.out_traces_candidates ) class BackendOptimizer: + """ + Represents a generic backend optimizer. + + Attributes: + optimizer: An optimizer instance based on the configurations. + """ def __init__( self, *, @@ -1066,13 +1188,29 @@ def __init__( ) def optimize(self): + """ + Optimize the executor placement for the current trace. + """ self.optimizer.optimize() def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + """ + Attach a new trace for executors optimization. + + Args: + trace: The trace to attach. + trace_type: Forward or backward trace refrence. + """ self.optimizer.attach_trace(trace=trace, trace_type=trace_type) def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + """ + Retrive the optimal forward traces that the object has tuned. + """ return self.optimizer.get_optimal_fw_traces() def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + """ + Retrive the optimal forward and backward trace pair. + """ return self.optimizer.get_optimal_fw_bw_traces() diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 6da45c363c..71faf60659 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -22,6 +22,72 @@ from itertools import chain import torch from thunder.core.dtypes import dtype +from enum import Enum + + +class LogLevel(Enum): + """ + Represents a log level. + """ + DEBUG = 0 + INFO = 1 + + +log_level: LogLevel = LogLevel.INFO + + +def log(what: str, level: LogLevel = LogLevel.INFO): + """ + Conditionally print to stdout. + + Args: + what: The content to print. + level: Tuning parameter to print the content only if allowed by the configuration. + """ + if log_level == LogLevel.DEBUG or log_level == level: + print(f"================================================================================ Autotune: {what}") + +class TraceType(Enum): + """ + Represents the nature of a trace, if forward (computational) or backward. + """ + FW = 0 + BW = 1 + + +class BenchmarkResult: + """ + Represents a trace benchmark result information. + + Attributes: + time: Benchmark computation time. + memory: Benchmark peak memory usage. + trace: Computaiton trace. + label: A generic label. + index: A generic index in a sequence. + """ + def __init__( + self, + *, + time: float = float("inf"), + memory: float = float("inf"), + trace: TraceCtx = TraceCtx(), + label: str | Hashable = "", + index: int = -1, + ) -> None: + self.runtime: float = time + self.memory: float = memory + self.trace: TraceCtx = trace + self.label: str | Hashable = label + self.index: int = index + + +class OptimizerType(Enum): + """ + Represents the autotuner target. + """ + MEMORY = 0 + RUNTIME = 1 # Maybe we can use id(s) @@ -137,7 +203,6 @@ def unpack_output(out) -> Sequence[Proxy]: break if not f: ans.append(e) - from thunder.backend_optimizer.optimizer import log, LogLevel log(f"Returning not used proxies: {[p.name if hasattr(p, 'name') else p for p in ans ]}", level=LogLevel.DEBUG) return ans @@ -730,8 +795,6 @@ def update_compile_options_executor_list_after_fw_bw_split() -> None: This utility can be used to save all the executors that have been effectively placed in a trace. """ - from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options - cd = get_compile_data() assert cd @@ -867,7 +930,6 @@ def reorder_executors_list(executors: Sequence, **kwargs): Args: executors: The executors to be reordered. """ - from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options from thunder.executors.torch_compile import torch_compile_ex from thunder.executors.nvfuserex_impl import ex as nvfuser_ex @@ -1377,3 +1439,34 @@ def map_executors_from_reduced_trace_to_complete_trace( ) return complete_trace_executors + +# This fn is used before compile data being set, rely on kwargs +def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None, **kwargs) -> list | dict: + """ + Retrieves the executors tuning options for the vector jacobian product pass. + These executors must be tuned at the vjp stage as we have to choose the correspective backward grad function. + + For new executors support the followig lists can be expanded. + + A guard is put for the transformer_engine_ex as its usage should not be tuned if not requested in a explicit way. + + Args: + bsym: The query bound symbol. + """ + from thunder.executors.sdpaex import sdpa_ex + from thunder.executors.cudnnex import cudnn_ex + from thunder.executors.fa3ex import fa3_ex + from thunder.executors.transformer_engineex import transformer_engine_ex + + if kwargs is None or not kwargs.get("autotune_enable_te", False): + options: dict[str, list] = { + "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], + } + else: + options: dict[str, list] = { + "linear": [transformer_engine_ex], + "scaled_dot_product_attention": [sdpa_ex, cudnn_ex, fa3_ex], + } + + return options.get(bsym.sym.name, []) if bsym else options + diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index ee90b90250..2e34d26a7b 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -221,9 +221,9 @@ def bw_fn(*args, **kwargs): return fw_fn, bw_fn # We have a backend else: - from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options + from thunder.backend_optimizer.utils import get_fw_bw_split_backends_options from thunder.backend_optimizer.utils import benchmark_trace - from thunder.backend_optimizer.optimizer import log, LogLevel + from thunder.backend_optimizer.utils import log, LogLevel # In order define this unique trace region we need an unique id key = (bsym.sym, Executor(f"{id(bsym)}-autotuned"), subkey := _make_cache_key(bsym.args, bsym.kwargs)) diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index baf24e7a02..07d4c6c98e 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -1,4 +1,4 @@ -from thunder.backend_optimizer.optimizer import get_fw_bw_split_backends_options +from thunder.backend_optimizer.utils import get_fw_bw_split_backends_options from thunder.core.dtypes import to_torch_dtype from thunder.core.prims import PrimIDs from thunder.core.proxies import FloatProxy, IntegerProxy, TensorProxy From 42e4b6b13edf4d48431ca6a77583e09819239fed Mon Sep 17 00:00:00 2001 From: matteochen Date: Sun, 1 Sep 2024 19:20:41 +0300 Subject: [PATCH 134/171] Using python logger --- examples/dev/LLaMAMLP.py | 7 +- examples/dev/litGPT.py | 1 + examples/dev/nanogpt.py | 1 + examples/dev/nvfuser_optimizations.py | 1 + examples/dev/sdpa.py | 7 +- examples/dev/te.py | 3 +- thunder/backend_optimizer/optimizer.py | 138 ++++++++++++------------- thunder/backend_optimizer/utils.py | 44 +++----- thunder/benchmarks/utils.py | 2 + thunder/core/vjp_utils.py | 17 ++- 10 files changed, 103 insertions(+), 118 deletions(-) diff --git a/examples/dev/LLaMAMLP.py b/examples/dev/LLaMAMLP.py index f69a687e9b..21ff039207 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/dev/LLaMAMLP.py @@ -2,6 +2,7 @@ This benchmark script is intended to demonstrate the optimizer on a generic model. No executor are given leaving full responsibility to the engine. """ + import torch import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_total_benchmark @@ -30,11 +31,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = LLaMAMLP(a, b) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit( - model, - autotune_type="runtime", - autotune_enable_te=True - ) + jmodel_auto = thunder.jit(model, autotune_type="runtime", autotune_enable_te=True) print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index a20209e136..56f9bb74b3 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -2,6 +2,7 @@ This script benchmarks litGPT models in a easier wrt to benchmark_litgpt.py way with a fake training loop with no optimizers in order to focus more on forward and backward computation time and not others kernel during the loop. """ + from litgpt import GPT from thunder.benchmarks.utils import ( thunder_fw_bw_benchmark, diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py index f0b1331e49..bd6e7323a7 100644 --- a/examples/dev/nanogpt.py +++ b/examples/dev/nanogpt.py @@ -2,6 +2,7 @@ This benchmark script is intended to demonstrate the optimizer on nanoGPT model. The script runner is taken from: https://github.com/karpathy/nanoGPT/blob/master/bench.py """ + import torch import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark diff --git a/examples/dev/nvfuser_optimizations.py b/examples/dev/nvfuser_optimizations.py index 7055a04df3..6a97191600 100644 --- a/examples/dev/nvfuser_optimizations.py +++ b/examples/dev/nvfuser_optimizations.py @@ -3,6 +3,7 @@ nvFuser compile options can be autotune with the argument `autotune_enable_nvfuser_all=True`. """ + import torch import thunder from thunder.benchmarks.utils import ( diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py index 6aeaed4d13..f50ff2bf84 100644 --- a/examples/dev/sdpa.py +++ b/examples/dev/sdpa.py @@ -4,6 +4,7 @@ Set the log level at least to INF0 in `thunder/backend_optimizer/optimizer.py`. """ + import torch import thunder from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_total_benchmark @@ -11,6 +12,7 @@ dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 torch.set_default_dtype(dtype) + class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -21,13 +23,12 @@ def forward(self, query, key, value): b = torch.nn.functional.scaled_dot_product_attention(query + query, key + key, value + value) return a + b + with torch.device("cuda"): model = Model() jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit( - model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa"] - ) + jmodel_auto = thunder.jit(model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa"]) q = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) k = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) diff --git a/examples/dev/te.py b/examples/dev/te.py index 6f8d08d691..f4ec28a606 100644 --- a/examples/dev/te.py +++ b/examples/dev/te.py @@ -3,6 +3,7 @@ This option can be enabled inside the autotuner by using the flag `autotune_enable_te=True`. """ + import torch import thunder from thunder.benchmarks.utils import ( @@ -41,7 +42,7 @@ def forward(self, x: torch.Tensor): "nvfuser", "transformer_engine", ], - autotune_enable_te=True + autotune_enable_te=True, ) y = jmodel_def(x) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 98a11add02..eff21819c2 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -11,15 +11,20 @@ from thunder.core.symbol import BoundSymbol from thunder.core.trace import from_trace, TraceCtx from thunder.core.transforms import construct_trace -from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors +from thunder.extend import Executor, FusionExecutor, get_always_executors from typing import Hashable -from thunder.backend_optimizer.utils import benchmark_trace, BenchmarkResult, OptimizerType, TraceType, LogLevel, log +from thunder.backend_optimizer.utils import benchmark_trace, BenchmarkResult, OptimizerType, TraceType +import logging + +logging.basicConfig(level=logging.INFO, format="{name} {message}", style="{") +logger = logging.getLogger("Autotuner: ") class OptimizationAlgorithm(Enum): """ Represents the optimization technique used by the autotuner. """ + BEST_FUSER = 0 @@ -34,6 +39,7 @@ class FusionCompileOptionsHelper: impl: A callable implementation. checker: A callable checker. """ + def __init__(self, fusion_tag: str, symbol_tag: str, id: PrimIDs, impl: Callable, checker: Callable) -> None: self.fusion_tag = fusion_tag self.symbol_tag = symbol_tag @@ -55,6 +61,7 @@ class TraceCandidate: impl: A callable implementation. checker: A callable checker. """ + def __init__(self, *, trace: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, label: str) -> None: self.trace: TraceCtx = trace self.compile_opt: FusionCompileOptionsHelper | None = compile_opt @@ -71,6 +78,7 @@ class TraceCandidates: compile_opt_time: Any compile options used for a fusion executor regarding the first trace. compile_opt_mem: Any compile options used for a fusion executor regarding the second trace. """ + def __init__( self, best_time: TraceCtx | None = None, @@ -130,6 +138,7 @@ class OutputCandidate: compile_opt: Any compile options being used for a fusion executor. tot_cost: The total cost to execute the pair (ms for a time strategy and GB for a memory strategy). """ + def __init__( self, *, fw: TraceCtx, bw: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, cost: float = 0.0 ) -> None: @@ -156,6 +165,7 @@ class FusionStratHelper: optimized_traces_time: a list of dictionaries containing informations regarding the optimized traces for total compute time. optimized_traces_time_benchmark_only: a list of dictionaries containing informations regarding the optimized traces for total compute time (used only for internal benchmarking). """ + def __init__(self) -> None: self.supported_executors: set = set(["nvfuser", "torchcompile"]) self.optimized_traces_mem: list[dict[str | Hashable, tuple[TraceCtx, FusionCompileOptionsHelper | None]]] = [] @@ -172,6 +182,7 @@ class FusionExecutorsPlacementCtx: placement: A list of executors. compile_options: Any compile options being used for the fusion executor contained in the placement. """ + def __init__(self, *, placement: list, compile_options: FusionCompileOptionsHelper | None = None) -> None: self.placement: list = placement self.compile_options: FusionCompileOptionsHelper | None = compile_options @@ -185,6 +196,7 @@ class ExecutorPlacementOptions: placement_options_mem: A list of placement contexts. placement_options_time: A list of placement contexts. """ + def __init__(self) -> None: self.placement_options_mem: list[FusionExecutorsPlacementCtx] = [] self.placement_options_time: list[FusionExecutorsPlacementCtx] = [] @@ -213,6 +225,7 @@ class PlacerBase: benchmark_iters: Benchmark iteration steps. compile_data: Thunder compilation data. """ + def __init__( self, *, @@ -290,6 +303,7 @@ class FusionPlacer_BeamSearch(PlacerBase): is_reduced: A flag indicating if the current trace under optimization is a reduced version of a bigger trace (by common blocks reduction). cached_original_trace: A reference to the original trace if the optmization is performed on a reduced version. """ + def __init__( self, *, @@ -314,7 +328,7 @@ def __init__( self.executor_placement_options: ExecutorPlacementOptions = ExecutorPlacementOptions() # nvFuser compile options - if compile_data.compile_options.get('autotune_enable_nvfuser_all', False): + if compile_data.compile_options.get("autotune_enable_nvfuser_all", False): from thunder.executors.nvfuserex_impl import linear, _linear_check from thunder.executors.nvfuserex_impl import matmul, _matmul_check @@ -326,8 +340,7 @@ def __init__( } else: self.known_fusion_ex_compile_options: dict[str | Hashable, list[FusionCompileOptionsHelper]] = { - "nvfuser": [ - ] + "nvfuser": [] } # Transformer based models optimization @@ -384,22 +397,22 @@ def _best_runtime_and_memory_candidates(self, candidates): pair_cost_time = 0 pair_cost_mem = 0 t, m, _ = benchmark_trace(fw, iters=self.benchmark_iters) - log(f"Pair fw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.DEBUG) + logger.debug(f"Pair fw time: {t} ms, mem: {m/(2**30)} GB") pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m t, m, _ = benchmark_trace(bw, iters=self.benchmark_iters, fw_trace=fw) - log(f"Pair bw time: {t} ms, mem: {m/(2**30)} GB", level=LogLevel.DEBUG) + logger.debug(f"Pair bw time: {t} ms, mem: {m/(2**30)} GB") pair_cost_time = pair_cost_time + t pair_cost_mem = pair_cost_mem + m if pair_cost_time < min_value_time: best_pair_runtime = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_time) - log(f"New best runtime pair (no remat):\n{best_pair_runtime}", level=LogLevel.DEBUG) + logger.debug(f"New best runtime pair (no remat):\n{best_pair_runtime}") min_value_time = pair_cost_time if pair_cost_mem < min_value_mem: best_pair_memory = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_mem) - log(f"New best memory pair (no remat):\n{best_pair_memory}", level=LogLevel.DEBUG) + logger.debug(f"New best memory pair (no remat):\n{best_pair_memory}") min_value_mem = pair_cost_mem return best_pair_runtime, best_pair_memory @@ -435,7 +448,7 @@ def fw_benchmark(): ) # For forward trace we cache the best placement for both runtime and memory for the current Fusion executor (represented by label) for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): - log(f'Caching fw candidate [compile option: {o.fusion_tag if o else "None"}]') + logger.info(f"Caching fw candidate [compile option: {o.fusion_tag if o else 'None'}]") self.cached_fw_traces.append( TraceCandidate( trace=t, compile_opt=o, label=label + "_enabled_" + o.fusion_tag if o is not None else label @@ -587,6 +600,7 @@ def _search(ex: FusionExecutor, executor_compile_option: FusionCompileOptionsHel ex: A fusion executor. executor_placement_options: Any compile option this executor might activate. """ + def _should_fuse_nvfuser(a: Node, b: Node): """ Fusable fn definition for nvFuser. @@ -595,6 +609,7 @@ def _should_fuse_nvfuser(a: Node, b: Node): a: First node. b: Second node. """ + def _can_fuse_node(n: Node): # if already merged, then node can be fused if len(n.group_bsyms) > 1: @@ -606,7 +621,6 @@ def _can_fuse_node(n: Node): return _can_fuse_node(a) and _can_fuse_node(b) - def _should_fuse_torchcompile(a: Node, b: Node): """ Fusable fn definition for torch.compile. @@ -615,6 +629,7 @@ def _should_fuse_torchcompile(a: Node, b: Node): a: First node. b: Second node. """ + def _can_fuse_node(n: Node): if len(n.group_bsyms) > 1: return True @@ -653,7 +668,7 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor case "torchcompile": merge_fn = _should_fuse_torchcompile bound_symbol_groups = fuse_bound_symbols(self.trace, merge_fn) - log(f"Number of Fusion groups = {len(bound_symbol_groups)}", level=LogLevel.DEBUG) + logger.debug(f"Number of Fusion groups = {len(bound_symbol_groups)}") # Print fusion groups if requested # for id, group in enumerate(bound_symbol_groups): @@ -668,14 +683,12 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor increasing_symbols = [] # Tuning starting point: iterate over all the groups. for group_id, group in enumerate(bound_symbol_groups): - log(f"Fusion group id: {group_id}", level=LogLevel.DEBUG) - log( - f"Fusion group start = [{group[0].output.name if hasattr(group[0].output, 'name') else 'unknown'} = {group[0].sym.name}]", - level=LogLevel.DEBUG, + logger.debug(f"Fusion group id: {group_id}") + logger.debug( + f"Fusion group start = [{group[0].output.name if hasattr(group[0].output, 'name') else 'unknown'} = {group[0].sym.name}]" ) - log( - f"Fusion group end = [{group[-1].output.name if hasattr(group[-1].output, 'name') else 'unknown'} = {group[-1].sym.name}]", - level=LogLevel.DEBUG, + logger.debug( + f"Fusion group end = [{group[-1].output.name if hasattr(group[-1].output, 'name') else 'unknown'} = {group[-1].sym.name}]" ) if group[0].sym.name != "return": @@ -684,9 +697,8 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor # We assign to a Fusion executor only region with at least 2 elements. Otherwise let the best OperatorExecutor pick the symbol up if len(group) < 2: current_bsym = group[0] - log( - f"--> Single group: [{current_bsym.output.name if hasattr(current_bsym.output, 'name') else 'unknown'} = {current_bsym.sym.name}]", - level=LogLevel.DEBUG, + logger.debug( + f"--> Single group: [{current_bsym.output.name if hasattr(current_bsym.output, 'name') else 'unknown'} = {current_bsym.sym.name}]" ) # Filter out all possible candidates for the current symbol candidate_executors = [ @@ -710,7 +722,7 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor ) continue else: - log(f"Available executors for single region:\n{candidate_executors}", level=LogLevel.DEBUG) + logger.debug(f"Available executors for single region:\n{candidate_executors}") # Define the standalone trace in order to benchmark this symbol subtrace = construct_trace()(current_bsym.sym, *current_bsym.args, **current_bsym.kwargs) @@ -729,14 +741,11 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor from thunder.common import transform_for_execution subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] - log(f"Subtrace to benchmark single symbol:\n{subtrace_placed}", level=LogLevel.DEBUG) + logger.debug(f"Subtrace to benchmark single symbol:\n{subtrace_placed}") t, m, _ = benchmark_trace( subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] ) - log( - f"Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB", - level=LogLevel.DEBUG, - ) + logger.debug(f"Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB") # Update results if t < candidate_best_time.runtime: candidate_best_time = BenchmarkResult(time=t, index=i) @@ -748,13 +757,11 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor f"Failed to get optimal single trace region candidate. Available candidates for {current_bsym.sym.name}:\n{candidate_executors}" ) - log( - f"Best time OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_time.index].name}", - level=LogLevel.DEBUG, + logger.debug( + f"Best time OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_time.index].name}" ) - log( - f"Best mem OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_mem.index].name}", - level=LogLevel.DEBUG, + logger.debug( + f"Best mem OperatorExecutor for single {current_bsym.sym.name}: {candidate_executors[candidate_best_mem.index].name}" ) match_bsym_executor(current_bsym, [dict_time_strat], candidate_executors[candidate_best_time.index]) @@ -781,7 +788,7 @@ def measure_and_update_result(): nonlocal best_keys_mem trc, keys, placements = get_placed_trace(dict_time_strat, increasing_symbols) cost, mem, _ = benchmark_trace(trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0]) - log(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}", level=LogLevel.DEBUG) + logger.debug(f"Placed trace (cost = {cost} ms, mem = {mem/(2**30)} GB)\n{trc}") if cost < best_res_time.runtime or (cost == best_res_time.runtime and mem < best_res_time.memory): best_res_time = BenchmarkResult(time=cost, memory=mem, trace=trc) best_placement_time = placements @@ -799,7 +806,7 @@ def measure_and_update_result(): for idx in range(0, len(group)): if group[idx].sym.name == "embedding_backward": last_embedding_idx = idx - log(f"last embedding idx: {last_embedding_idx}", level=LogLevel.DEBUG) + logger.debug(f"last embedding idx: {last_embedding_idx}") if last_embedding_idx != -1: # Until last_embedding_idx (included) assigned to current fusion ex for i in range(0, last_embedding_idx + 1, 1): @@ -844,13 +851,11 @@ def measure_and_update_result(): if best_placement_mem is None or best_keys_mem is None: raise AssertionError("Failed to get best placement") - log( - f"For group {group_id} best placement with time cost = {best_res_time.runtime} ms:\n{best_res_time.trace}", - level=LogLevel.DEBUG, + logger.debug( + f"For group {group_id} best placement with time cost = {best_res_time.runtime} ms:\n{best_res_time.trace}" ) - log( - f"For group {group_id} best placement with mem cost = {best_res_mem.memory / (2**30)} GB:\n{best_res_mem.trace}", - level=LogLevel.DEBUG, + logger.debug( + f"For group {group_id} best placement with mem cost = {best_res_mem.memory / (2**30)} GB:\n{best_res_mem.trace}" ) # Update our dict @@ -931,10 +936,9 @@ def measure_and_update_result(): ex: FusionExecutor for ex in self.fusion_executors: if ex.name not in self.fusion_strat_helper.supported_executors: - # log(f"Fusion operator not supported: {ex.name}. Skipping it.") continue - log(f"Searching best placement for fusion executor = {ex.name}", level=LogLevel.INFO) + logger.info(f"Searching best placement for fusion executor = {ex.name}") # We try to enable fusion specific compile options only for fw traces # Backward traces will follow fw traces options @@ -949,7 +953,7 @@ def measure_and_update_result(): # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. # TODO: Consider implementing patterns based on the executor under investingation if ex_compile_opts: - log(f"{ex.name} compile options: {[option.fusion_tag for option in ex_compile_opts]}") + logger.info(f"{ex.name} compile options: {[option.fusion_tag for option in ex_compile_opts]}") for opt in ex_compile_opts: # Search only if we have an instruction related to the compile option op_in_trace: bool = operation_in_trace(trace=self.trace, op=opt.symbol_tag) @@ -957,6 +961,8 @@ def measure_and_update_result(): self.fusion_executors_saved_for_later.append(ex) wrap_fn_with_exeuctor_compile_option(opt, _search, ex, opt) + logger.info(f"Searching best placement for fusion executor = {ex.name} ended.") + """ ################################################## Public methods ################################################## """ @@ -982,17 +988,12 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): match self.trace_type: case TraceType.FW: - log( - f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", level=LogLevel.INFO - ) + logger.info(f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") # TODO (matteochen): support bw trace optimization even though with no fw traces cached (computational trace?) case TraceType.BW: if not self.cached_fw_traces: raise AssertionError("Can not optimize backward traces before forward traces") - log( - f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}", - level=LogLevel.INFO, - ) + logger.info(f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") def optimize(self): from thunder.core.transform_common import dce @@ -1008,7 +1009,9 @@ def _optimize(): cd = get_compile_data() # Check if common blocks optimization is requested - optimize_common_blocks = False if cd is None else cd.compile_options.get("autotune_optimize_common_blocks", False) + optimize_common_blocks = ( + False if cd is None else cd.compile_options.get("autotune_optimize_common_blocks", False) + ) optimize_common_blocks_min_size = ( -1 if cd is None else cd.compile_options.get("autotune_optimize_common_blocks_min_size", -1) ) @@ -1019,19 +1022,15 @@ def _optimize(): ) # print(common_trace_blocks) if len(common_trace_blocks) >= 2 and optimize_common_blocks: - log(f"Common blocks found {common_trace_blocks}", level=LogLevel.INFO) + logger.info(f"Running optimization with common blocks reduction. Found {common_trace_blocks}") reduced_trace = reduce_common_trace_blocks(trace=self.trace, common_blocks_in=common_trace_blocks) - log( - f"Operating on reduced trace (by cutting common transformer blocks):\n{reduced_trace}", - level=LogLevel.INFO, - ) + logger.info(f"Operating on reduced trace (by cutting common transformer blocks):\n{reduced_trace}") self.is_reduced = True self.cached_original_trace = self.trace self.trace = reduced_trace else: - log( - "Optimizing the whole trace directly. No common transformer block optimization will be applied.", - level=LogLevel.INFO, + logger.info( + "Optimizing the whole trace directly. No common transformer block optimization will be applied." ) # This performs executor tuning @@ -1115,18 +1114,17 @@ def _optimize(): # Now we can generate backward solutions from the cached fw traces for fw_trace_candidate in self.cached_fw_traces: - log(f"Backward optimization with fw from {fw_trace_candidate.label}", level=LogLevel.INFO) + logger.info(f"Backward optimization with fw from {fw_trace_candidate.label}") # Restore the original bw trace self.trace = from_trace(cached_self_trace) self.trace.bound_symbols = list(cached_self_trace.bound_symbols) # Set the current active cached forward trace context - log( - f'Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.compile_opt.fusion_tag if fw_trace_candidate.compile_opt is not None else "None"}', - level=LogLevel.DEBUG, + logger.info( + f"Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.compile_opt.fusion_tag if fw_trace_candidate.compile_opt is not None else 'None'}" ) self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.compile_opt - log(f"Input bw trace:\n{self.trace}", level=LogLevel.DEBUG) + logger.debug(f"Input bw trace:\n{self.trace}") self.trace = update_bw_from_forward_optimization(fw=fw_trace_candidate.trace, bw=self.trace) @@ -1158,6 +1156,7 @@ class BackendOptimizer: Attributes: optimizer: An optimizer instance based on the configurations. """ + def __init__( self, *, @@ -1180,12 +1179,7 @@ def __init__( compile_data=compile_data, ) - log("Executors:", level=LogLevel.INFO) - for e in priority_executors: - log( - f"{e.name} -> is operator = {isinstance(e, OperatorExecutor)}, is fusion = {isinstance(e, FusionExecutor)}", - level=LogLevel.INFO, - ) + logger.info(f"Executors: {[ex.name for ex in priority_executors]}") def optimize(self): """ diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 71faf60659..1464fae1e4 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -25,32 +25,11 @@ from enum import Enum -class LogLevel(Enum): - """ - Represents a log level. - """ - DEBUG = 0 - INFO = 1 - - -log_level: LogLevel = LogLevel.INFO - - -def log(what: str, level: LogLevel = LogLevel.INFO): - """ - Conditionally print to stdout. - - Args: - what: The content to print. - level: Tuning parameter to print the content only if allowed by the configuration. - """ - if log_level == LogLevel.DEBUG or log_level == level: - print(f"================================================================================ Autotune: {what}") - class TraceType(Enum): """ Represents the nature of a trace, if forward (computational) or backward. """ + FW = 0 BW = 1 @@ -66,6 +45,7 @@ class BenchmarkResult: label: A generic label. index: A generic index in a sequence. """ + def __init__( self, *, @@ -86,6 +66,7 @@ class OptimizerType(Enum): """ Represents the autotuner target. """ + MEMORY = 0 RUNTIME = 1 @@ -99,6 +80,7 @@ def sequence_hash(s: Sequence) -> str: Args: s: A sequence to hash. """ + def rec(s) -> str: name = "[" for e in s: @@ -177,6 +159,7 @@ def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[Proxy]: Args: in_trace: A generic trace. """ + def is_in_sequence(seq: Sequence[Any], t: Proxy): for e in seq: if hasattr(e, "name") and hasattr(t, "name") and e.name == t.name: @@ -203,8 +186,9 @@ def unpack_output(out) -> Sequence[Proxy]: break if not f: ans.append(e) + from thunder.backend_optimizer.optimizer import logger - log(f"Returning not used proxies: {[p.name if hasattr(p, 'name') else p for p in ans ]}", level=LogLevel.DEBUG) + logger.debug(f"Returning not used proxies: {[p.name if hasattr(p, 'name') else p for p in ans ]}") return ans @@ -674,6 +658,7 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) except Exception: import traceback + traceback.print_exc() finally: reset_tracectx(trace_tok) @@ -799,7 +784,9 @@ def update_compile_options_executor_list_after_fw_bw_split() -> None: assert cd # Get all the possible options that the vjp_optimization pass will use - options: dict = get_fw_bw_split_backends_options(autotune_enable_te=cd.compile_options.get('autotune_enable_te', False)) + options: dict = get_fw_bw_split_backends_options( + autotune_enable_te=cd.compile_options.get("autotune_enable_te", False) + ) executors_list = list(cd.executors_list) # Remove all the initial options @@ -1049,7 +1036,7 @@ def repetead_trace_blocks( return [] if known_points is not None: - raise RuntimeError('known_points research is not supported.') + raise RuntimeError("known_points research is not supported.") symbols = [ s @@ -1158,7 +1145,7 @@ def _range_seen(index: int, s: set): if max_lcs < min_block_size: return [] - print(f'\n\nMax lcs {max_lcs}') + print(f"\n\nMax lcs {max_lcs}") # print(res) for r in res: @@ -1187,6 +1174,7 @@ def _regions_between_blocks(trace: TraceCtx, common_blocks: list[tuple]) -> int: common_blocks: A list containig the common blocks for the given trace. """ + def _assert_args(seq_a: Sequence, seq_b: Sequence): assert len(seq_a) == len(seq_b) for a, b in zip(seq_a, seq_b): @@ -1246,6 +1234,7 @@ def reduce_common_trace_blocks( common_blocks_in: A previously computed common block pattern. skip_between_blocks: A flag to control if gaps between common blocks should be included in the output trace or not. See _regions_between_blocks. """ + def _exclude(blocks: list[tuple[int, int]], index: int, black_list: set): # Exclude if the index is in a repeated block for block in blocks: @@ -1352,6 +1341,7 @@ def _correct_bsym(bsym: BoundSymbol) -> BoundSymbol: bound_symbols[-1] = bsym # Bw trace else: + def _returned(seq: Sequence) -> tuple: ret = [] for e in seq: @@ -1440,6 +1430,7 @@ def map_executors_from_reduced_trace_to_complete_trace( return complete_trace_executors + # This fn is used before compile data being set, rely on kwargs def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None, **kwargs) -> list | dict: """ @@ -1469,4 +1460,3 @@ def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None, **kwargs) } return options.get(bsym.sym.name, []) if bsym else options - diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 2a541e985f..c6977eca5d 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -4,6 +4,7 @@ warm_up_iters = 50 + class SplitFwBwBenchmarkUtils: """ Represents a benchmark result container. @@ -15,6 +16,7 @@ class SplitFwBwBenchmarkUtils: bw_fn: Storage for a backward trace. executor: An OperatorExecutor. """ + def __init__( self, *, cost: float = float("inf"), fw_fn: Callable | None = None, bw_fn: Callable | None = None, executor=None ) -> None: diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 2e34d26a7b..39cb278189 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -221,9 +221,11 @@ def bw_fn(*args, **kwargs): return fw_fn, bw_fn # We have a backend else: - from thunder.backend_optimizer.utils import get_fw_bw_split_backends_options + from thunder.backend_optimizer.optimizer import OptimizerType + from thunder.backend_optimizer.optimizer import logger from thunder.backend_optimizer.utils import benchmark_trace - from thunder.backend_optimizer.utils import log, LogLevel + from thunder.backend_optimizer.utils import get_fw_bw_split_backends_options + from thunder.benchmarks.utils import SplitFwBwBenchmarkUtils # In order define this unique trace region we need an unique id key = (bsym.sym, Executor(f"{id(bsym)}-autotuned"), subkey := _make_cache_key(bsym.args, bsym.kwargs)) @@ -245,8 +247,6 @@ def bw_fn(*args, **kwargs): cached_executors_list = list(cd.executors_list) # Retrieve all the executors which are requested to be used requested_executors_list_for_bsym = [ex for ex in cached_executors_list if ex in backends] - from thunder.benchmarks.utils import SplitFwBwBenchmarkUtils - from thunder.backend_optimizer.optimizer import OptimizerType best = SplitFwBwBenchmarkUtils() @@ -263,9 +263,9 @@ def bw_fn(*args, **kwargs): # Restrict the search space backends = list(requested_executors_list_for_bsym) - log(f"Search space for {bsym.sym.name}: {backends}", level=LogLevel.INFO) + logger.info(f"Search space for {bsym.sym.name}: {backends}") for b in backends: - log(f"Benchmarking executor {b.name} for {bsym.sym.name}", level=LogLevel.INFO) + logger.info(f"Benchmarking executor {b.name} for {bsym.sym.name}") # Let downstream fn to pick up this requested_executors_list_for_bsym.remove(b) requested_executors_list_for_bsym.insert(0, b) @@ -283,10 +283,7 @@ def bw_fn(*args, **kwargs): assert best.cost != float("inf") - log( - f"Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}", - level=LogLevel.INFO, - ) + logger.info(f"Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}") # Cache the bsym result for common trace's common block reductions if bsym.sym.name in ['linear', 'scaled_dot_product_attention'] and optmimizer_common_transformer_block: From e494514af3398f9bc7bbd5a3c3a9a54a92544c77 Mon Sep 17 00:00:00 2001 From: matteochen Date: Sun, 1 Sep 2024 19:54:15 +0300 Subject: [PATCH 135/171] Moved cache to compile data --- thunder/__init__.py | 3 +++ thunder/common.py | 2 ++ thunder/core/vjp_utils.py | 18 +++++++++++------- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 037a01d334..cce32c1524 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -699,6 +699,9 @@ def get_computation_and_inputs(*args, **kwargs): # Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces # by split_forward_backward + # Reset the cache for the next compilation + cd.autotuner_bsym_with_gradfn_executor_cache = {} + if backward_trc is None: ## EPILOGUE and TRANSFORMS should not mix... # applies transforms diff --git a/thunder/common.py b/thunder/common.py index 3a40c6a56a..283b5b347d 100644 --- a/thunder/common.py +++ b/thunder/common.py @@ -280,6 +280,8 @@ def __init__( self.additional_return_names = None self.num_constant_args = 0 + self.autotuner_bsym_with_gradfn_executor_cache: dict = {} + assert disable_preprocessing, "please use thunder.compile if you need preprocessing" diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index 39cb278189..a22b23214e 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -17,7 +17,6 @@ _cache = {} -_autototune_common_bsym_in_blocks_cache = {} def disable_caching_split_forward_and_backward(fn): @@ -253,17 +252,21 @@ def bw_fn(*args, **kwargs): # Do we have a common transformer block optimization enabled? # If yes we have to restrict the same executor on every bsym # in the transformer block (e.g. every scaled_dot_product in every transformer block will have the same executor - # as they are expected to work on same input size and shapes). + # as they are expected to work on same input size, shape and dtype). optmimizer_common_transformer_block = cd.compile_options.get('autotune_optimize_common_blocks', False) # The generated hash will rely on the operation, input args metadata and output metadata h = symbol_hash(bsym) - if h in _autototune_common_bsym_in_blocks_cache and optmimizer_common_transformer_block: - best = _autototune_common_bsym_in_blocks_cache[h] + # Recover the cache stored in the compile data + autotuner_bsym_with_gradfn_executor_cache = cd.autotuner_bsym_with_gradfn_executor_cache + + # Run the search only if not already visited before + if h in autotuner_bsym_with_gradfn_executor_cache and optmimizer_common_transformer_block: + best = autotuner_bsym_with_gradfn_executor_cache[h] else: # Restrict the search space backends = list(requested_executors_list_for_bsym) - logger.info(f"Search space for {bsym.sym.name}: {backends}") + logger.info(f"Search space for bsym {bsym.sym.name}: {backends}") for b in backends: logger.info(f"Benchmarking executor {b.name} for {bsym.sym.name}") # Let downstream fn to pick up this @@ -286,8 +289,9 @@ def bw_fn(*args, **kwargs): logger.info(f"Best executor for symbol [{bsym.output.name} = {bsym.sym.name}]: {best.executor.name}") # Cache the bsym result for common trace's common block reductions - if bsym.sym.name in ['linear', 'scaled_dot_product_attention'] and optmimizer_common_transformer_block: - _autototune_common_bsym_in_blocks_cache[h] = best + # At this stage we are tuning trace regions for these symbols name: linear and scaled_dot_product_attention + if bsym.sym.name in get_fw_bw_split_backends_options().keys() and optmimizer_common_transformer_block: + autotuner_bsym_with_gradfn_executor_cache[h] = best # Update the compile options cd.compile_options["autotune_executors_placed_by_fw_bw_split"].add(best.executor) From 00bff88afc7294c62d7dead30b6f6a5ef48f7f4a Mon Sep 17 00:00:00 2001 From: matteochen Date: Sun, 1 Sep 2024 19:54:48 +0300 Subject: [PATCH 136/171] Updated logger --- thunder/backend_optimizer/optimizer.py | 6 +++--- thunder/backend_optimizer/utils.py | 7 +++---- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index eff21819c2..bf6b6429e2 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -16,8 +16,8 @@ from thunder.backend_optimizer.utils import benchmark_trace, BenchmarkResult, OptimizerType, TraceType import logging -logging.basicConfig(level=logging.INFO, format="{name} {message}", style="{") -logger = logging.getLogger("Autotuner: ") +logging.basicConfig(level=logging.INFO, format="[{name}]: {message}", style="{") +logger = logging.getLogger("Thunder Autotuner") class OptimizationAlgorithm(Enum): @@ -1022,7 +1022,7 @@ def _optimize(): ) # print(common_trace_blocks) if len(common_trace_blocks) >= 2 and optimize_common_blocks: - logger.info(f"Running optimization with common blocks reduction. Found {common_trace_blocks}") + logger.info(f"Running optimization with common blocks reduction. Found block indices in trace: {common_trace_blocks}") reduced_trace = reduce_common_trace_blocks(trace=self.trace, common_blocks_in=common_trace_blocks) logger.info(f"Operating on reduced trace (by cutting common transformer blocks):\n{reduced_trace}") self.is_reduced = True diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 1464fae1e4..cf301929d9 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -1145,11 +1145,10 @@ def _range_seen(index: int, s: set): if max_lcs < min_block_size: return [] - print(f"\n\nMax lcs {max_lcs}") - # print(res) + from thunder.backend_optimizer.optimizer import logger + logger.debug(f"Max block lcs fouund: {max_lcs}") + logger.debug(f"{[(symbols[r[0]].output.name, symbols[r[1]].output.name) for r in res]}") - for r in res: - print(symbols[r[0]].output.name, symbols[r[1]].output.name) return [ (original_map_indexes[symbols[t[0]].output.name], original_map_indexes[symbols[t[1]].output.name]) for t in res ] From 2d8eacedbbfaefff5715dd349a2f66851d9249fa Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 5 Sep 2024 10:49:34 +0200 Subject: [PATCH 137/171] Trace configuration dumps and restore (#20) --- thunder/backend_optimizer/optimizer.py | 286 ++++++++++++++++++------- thunder/backend_optimizer/utils.py | 214 ++++++++++++++++-- thunder/core/vjp_utils.py | 2 +- thunder/executors/passes.py | 14 +- thunder/tests/test_autotuner.py | 65 +++++- 5 files changed, 476 insertions(+), 105 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index bf6b6429e2..d6b496616e 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1,9 +1,11 @@ from collections.abc import Callable, Sequence from enum import Enum from thunder.backend_optimizer.utils import ( + dump_traces_placement, map_executors_from_reduced_trace_to_complete_trace, operation_in_trace, wrap_fn_with_exeuctor_compile_option, + apply_results_from_file, ) from thunder.core.compile_data import get_compile_data from thunder.core.prims import PrimIDs @@ -48,23 +50,39 @@ def __init__(self, fusion_tag: str, symbol_tag: str, id: PrimIDs, impl: Callable self.checker: Callable = checker +class FusionExecutorsPlacementCtx: + """ + Represents a executor placement context. + + Attributes: + placement: A list of executors. + compile_options: Any compile options being used for the fusion executor contained in the placement. + """ + + def __init__(self, *, placement: list, compile_options: FusionCompileOptionsHelper | None = None) -> None: + self.placement: list = placement + self.compile_options: FusionCompileOptionsHelper | None = compile_options + + class TraceCandidate: """ Represents an optimal trace candidate. Attributes: trace: The candidate trace. - compile_opt: Any compile options used for the current candidate. - label: A generic label. - symbol_tag: The symbol name - id: The symbol id. - impl: A callable implementation. - checker: A callable checker. + ctx: Trace's placement context. + label: A generic label to identify this candidate. """ - def __init__(self, *, trace: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, label: str) -> None: + def __init__( + self, + *, + trace: TraceCtx, + ctx: FusionExecutorsPlacementCtx, + label: str, + ) -> None: self.trace: TraceCtx = trace - self.compile_opt: FusionCompileOptionsHelper | None = compile_opt + self.ctx: FusionExecutorsPlacementCtx = ctx self.label: str = label @@ -75,21 +93,21 @@ class TraceCandidates: Attributes: best_time: The trace with the optimal runtime. best_mem: The trace with the optimal peak memory consumption. - compile_opt_time: Any compile options used for a fusion executor regarding the first trace. - compile_opt_mem: Any compile options used for a fusion executor regarding the second trace. + placement_ctx_time: Trace placement context: exeuctors and any applied fusion compile options. + placement_ctx_mem: Trace placement context: exeuctors and any applied fusion compile options. """ def __init__( self, best_time: TraceCtx | None = None, best_mem: TraceCtx | None = None, - compile_opt_time: FusionCompileOptionsHelper | None = None, - compile_opt_mem: FusionCompileOptionsHelper | None = None, + placement_ctx_time: FusionExecutorsPlacementCtx | None = None, + placement_ctx_mem: FusionExecutorsPlacementCtx | None = None, ) -> None: self.best_time: TraceCtx | None = best_time self.best_mem: TraceCtx | None = best_mem - self.compile_opt_time: FusionCompileOptionsHelper | None = compile_opt_time - self.compile_opt_mem: FusionCompileOptionsHelper | None = compile_opt_mem + self.placement_ctx_time: FusionExecutorsPlacementCtx | None = placement_ctx_time + self.placement_ctx_mem: FusionExecutorsPlacementCtx | None = placement_ctx_mem def __repr__(self) -> str: """ @@ -103,29 +121,45 @@ def is_set(self) -> bool: """ return False if self.best_time is None or self.best_mem is None else True - def attach_best_time_candidate(self, trace: TraceCtx): + def attach_best_time_candidate(self, trace: TraceCtx, ctx: FusionExecutorsPlacementCtx | None = None): """ Attach a new best time trace result. + + Args: + trace: The trace to assign. + ctx: The trace placement context. """ self.best_time = trace + self.placement_ctx_time = ctx - def attach_best_mem_candidate(self, trace: TraceCtx): + def attach_best_mem_candidate(self, trace: TraceCtx, ctx: FusionExecutorsPlacementCtx | None = None): """ Attach a new best memory trace result. + + Args: + trace: The trace to assign. + ctx: The trace placement context. """ self.best_mem = trace + self.placement_ctx_mem = ctx - def iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: + def iterable(self) -> tuple[tuple, tuple]: """ - Returns an iterable object over the traces contained in the current object. + Returns an iterable object over the traces paired with their contexts. + """ + return (self.best_time, self.placement_ctx_time), (self.best_mem, self.placement_ctx_mem) + + def trace_ctx_iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: + """ + Returns an iterable object over the traces. """ return self.best_time, self.best_mem - def compile_opt_iterables(self) -> tuple[FusionCompileOptionsHelper | None, FusionCompileOptionsHelper | None]: + def placement_ctx_iterable(self) -> tuple[FusionExecutorsPlacementCtx | None, FusionExecutorsPlacementCtx | None]: """ - Returns an iterable object over the compile options used in the traces contained in the current object. + Returns an iterable object over the placement contexts. """ - return self.compile_opt_time, self.compile_opt_mem + return self.placement_ctx_time, self.placement_ctx_mem class OutputCandidate: @@ -135,17 +169,31 @@ class OutputCandidate: Attributes: fw: The forward trace. bw: The backward trace. - compile_opt: Any compile options being used for a fusion executor. + executors_fw: The forward trace regions' executors + executors_bw: The backward trace regions' executors + compile_opt: Any compile options being used for a fusion executor in the forward trace. tot_cost: The total cost to execute the pair (ms for a time strategy and GB for a memory strategy). + apply_remat: If rematerialization has been applied. """ def __init__( - self, *, fw: TraceCtx, bw: TraceCtx, compile_opt: FusionCompileOptionsHelper | None = None, cost: float = 0.0 + self, + *, + fw: TraceCtx, + bw: TraceCtx, + executors_fw: list[Executor], + executors_bw: list[Executor], + compile_opt: FusionCompileOptionsHelper | None = None, + cost: float = 0.0, + apply_remat: bool = False, ) -> None: self.fw: TraceCtx = fw self.bw: TraceCtx = bw + self.executors_fw: list[Executor] = executors_fw + self.executors_bw: list[Executor] = executors_bw self.compile_opt: FusionCompileOptionsHelper | None = compile_opt self.tot_cost: float = cost + self.apply_remat = apply_remat def __repr__(self) -> str: """ @@ -168,26 +216,12 @@ class FusionStratHelper: def __init__(self) -> None: self.supported_executors: set = set(["nvfuser", "torchcompile"]) - self.optimized_traces_mem: list[dict[str | Hashable, tuple[TraceCtx, FusionCompileOptionsHelper | None]]] = [] + self.optimized_traces_mem: list[dict[str | Hashable, tuple[TraceCtx, FusionExecutorsPlacementCtx | None]]] = [] self.optimized_traces_mem_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] - self.optimized_traces_time: list[dict[str | Hashable, tuple[TraceCtx, FusionCompileOptionsHelper | None]]] = [] + self.optimized_traces_time: list[dict[str | Hashable, tuple[TraceCtx, FusionExecutorsPlacementCtx | None]]] = [] self.optimized_traces_time_benchmark_only: list[dict[str | Hashable, TraceCtx]] = [] -class FusionExecutorsPlacementCtx: - """ - Represents a executor placement context. - - Attributes: - placement: A list of executors. - compile_options: Any compile options being used for the fusion executor contained in the placement. - """ - - def __init__(self, *, placement: list, compile_options: FusionCompileOptionsHelper | None = None) -> None: - self.placement: list = placement - self.compile_options: FusionCompileOptionsHelper | None = compile_options - - class ExecutorPlacementOptions: """ Represents an aggregate placement options for executors combining those that targets peak memory consumption and those for total compute time. @@ -216,8 +250,10 @@ class PlacerBase: log_file_name: The output log file name if generated. produce_log: A tuning parameter to control log file generation. optimizer_type: The optimization target. - active_fw_trace_ctx: An active set forward trace (in the object scope). - cached_fw_traces: Cached forward traces. + active_fw_trace_ctx: An active forward trace set to optimize backward. + cached_fw_traces: Cached optimized forward traces. + cached_computational_trace: Original computational trace + cached_computational_backward_trace: Original computational backward trace bw_trace_candidates: An instance of trace candidates. best_pair_runtime: A final traace pair targetting the compute time. best_pair_memory: A final traace pair targetting the peak memory consumption. @@ -251,8 +287,10 @@ def __init__( self.optimizer_type: OptimizerType = optimizer_type - self.active_fw_trace_ctx: tuple[TraceCtx | None, FusionCompileOptionsHelper | None] = None, None + self.active_fw_trace_ctx: tuple[TraceCtx | None, FusionExecutorsPlacementCtx | None] = None, None self.cached_fw_traces: list[TraceCandidate] = [] + self.cached_computational_trace: TraceCtx = TraceCtx() + self.cached_computational_backward_trace: TraceCtx = TraceCtx() self.bw_trace_candidates: TraceCandidates = TraceCandidates() self.out_traces_candidates: list[OutputCandidate] = [] self.best_pair_runtime: OutputCandidate @@ -270,7 +308,7 @@ def optimize(self): """ pass - def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce=True): """ Attach a new trace for executors optimization. @@ -353,7 +391,7 @@ def __init__( ################################################## Internal methods ################################################## """ - def _best_runtime_and_memory_candidates(self, candidates): + def _best_runtime_and_memory_candidates(self, candidates: Sequence[OutputCandidate]): """ Retrive the best compute time and peak memory consumption trace pairs. @@ -376,9 +414,11 @@ def _best_runtime_and_memory_candidates(self, candidates): else: remat_fw, remat_bw = rematerialize_forward_and_backward(pair.fw, pair.bw) # Create pair final options by applying final optimizations: cudagraphs and rematerialization - pair_options: list[tuple[TraceCtx, TraceCtx]] = [ - (pair.fw, pair.bw), - (remat_fw, remat_bw), + pair_options: list[ + tuple[TraceCtx, TraceCtx, FusionCompileOptionsHelper | None, list[Executor], list[Executor], bool] + ] = [ + (pair.fw, pair.bw, pair.compile_opt, pair.executors_fw, pair.executors_bw, False), + (remat_fw, remat_bw, pair.compile_opt, pair.executors_fw, pair.executors_bw, True), ] # if self.compile_data.use_cudagraphs is not None and self.compile_data.use_cudagraphs: # from thunder.executors.cudagraphex import cudagraphex @@ -391,8 +431,7 @@ def _best_runtime_and_memory_candidates(self, candidates): # ) # Select the best options for pair_option in pair_options: - fw = pair_option[0] - bw = pair_option[1] + fw, bw, compile_opt, executors_fw, executors_bw, remat_applied = pair_option pair_cost_time = 0 pair_cost_mem = 0 @@ -406,12 +445,28 @@ def _best_runtime_and_memory_candidates(self, candidates): pair_cost_mem = pair_cost_mem + m if pair_cost_time < min_value_time: - best_pair_runtime = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_time) + best_pair_runtime = OutputCandidate( + fw=fw, + bw=bw, + compile_opt=compile_opt, + executors_fw=executors_fw, + executors_bw=executors_bw, + cost=pair_cost_time, + apply_remat=remat_applied, + ) logger.debug(f"New best runtime pair (no remat):\n{best_pair_runtime}") min_value_time = pair_cost_time if pair_cost_mem < min_value_mem: - best_pair_memory = OutputCandidate(fw=fw, bw=bw, cost=pair_cost_mem) + best_pair_memory = OutputCandidate( + fw=fw, + bw=bw, + compile_opt=compile_opt, + executors_fw=executors_fw, + executors_bw=executors_bw, + cost=pair_cost_mem, + apply_remat=remat_applied, + ) logger.debug(f"New best memory pair (no remat):\n{best_pair_memory}") min_value_mem = pair_cost_mem @@ -434,8 +489,12 @@ def fw_benchmark(): for pair_time, pair_mem in zip( self.fusion_strat_helper.optimized_traces_time, self.fusion_strat_helper.optimized_traces_mem ): - trc_time, compile_opt_time = list(pair_time.values())[0] - trc_mem, compile_opt_mem = list(pair_mem.values())[0] + placement_ctx_time: FusionExecutorsPlacementCtx + placement_ctx_mem: FusionExecutorsPlacementCtx + trc_time: TraceCtx + trc_mem: TraceCtx + trc_time, placement_ctx_time = list(pair_time.values())[0] + trc_mem, placement_ctx_mem = list(pair_mem.values())[0] label = list(pair_time.keys())[0] # TODO (matteochen): remove the benchmark here as will done later on the bw pass c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) @@ -447,13 +506,20 @@ def fw_benchmark(): f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) # For forward trace we cache the best placement for both runtime and memory for the current Fusion executor (represented by label) - for t, o in zip([trc_time, trc_mem], [compile_opt_time, compile_opt_mem]): - logger.info(f"Caching fw candidate [compile option: {o.fusion_tag if o else 'None'}]") + for t, ctx in zip([trc_time, trc_mem], [placement_ctx_time, placement_ctx_mem]): + logger.info( + f"Caching fw candidate [compile option: {ctx.compile_options.fusion_tag if ctx.compile_options else 'None'}]" + ) self.cached_fw_traces.append( TraceCandidate( - trace=t, compile_opt=o, label=label + "_enabled_" + o.fusion_tag if o is not None else label + trace=t, + ctx=ctx, + label=(label + "_enabled_" + ctx.compile_options.fusion_tag) + if ctx.compile_options is not None + else label, ) ) + self.cached_computational_trace = self.trace def bw_benchmark(): time_result = BenchmarkResult() @@ -488,18 +554,29 @@ def bw_benchmark(): # Here we have to recover the traces without the pass through remat in order to be compliant # with thunder flow as we might have request for no remat. - trc = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0][0] - self.bw_trace_candidates.attach_best_time_candidate(trc) - trc = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0][0] - self.bw_trace_candidates.attach_best_mem_candidate(trc) + trc, placement_ctx = list(self.fusion_strat_helper.optimized_traces_time[time_result.index].values())[0] + self.bw_trace_candidates.attach_best_time_candidate(trc, placement_ctx) + trc, placement_ctx = list(self.fusion_strat_helper.optimized_traces_mem[memory_result.index].values())[0] + self.bw_trace_candidates.attach_best_mem_candidate(trc, placement_ctx) # Now, finally build the pair fw and bw traces # The current fw trace is set by the caller and we take it as is. All current bw traces optimizations are made with the fw trace set by the caller. + + assert self.active_fw_trace_ctx[0] is not None and self.active_fw_trace_ctx[1] is not None + for bw in self.bw_trace_candidates.iterable(): self.out_traces_candidates.append( - OutputCandidate(fw=self.active_fw_trace_ctx[0], bw=bw, compile_opt=self.active_fw_trace_ctx[1]) + OutputCandidate( + fw=self.active_fw_trace_ctx[0], + bw=bw[0], + executors_fw=self.active_fw_trace_ctx[1].placement, + executors_bw=bw[1].placement, + compile_opt=self.active_fw_trace_ctx[1].compile_options, + ) ) + self.cached_computational_backward_trace = self.trace + match self.trace_type: case TraceType.FW: fw_benchmark() @@ -973,26 +1050,40 @@ def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: return [candidate.trace for candidate in self.cached_fw_traces] def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: + restore_file = self.compile_data.compile_options.get("autotune_restore_configuration", "") + + # We apply the dce transform as it will be applied to the cached traces during the past optimization + # (dce has been applied to the traces saved in the configuration). + if restore_file: + from thunder.core.transforms import dce + + fw_extrace, bw_extrace = apply_results_from_file( + fw_trace=dce(self.cached_computational_trace), + bw_trace=dce(self.cached_computational_backward_trace), + file=restore_file, + ) + return fw_extrace, bw_extrace return ( (self.best_pair_runtime.fw, self.best_pair_runtime.bw) if self.optimizer_type == OptimizerType.RUNTIME else (self.best_pair_memory.fw, self.best_pair_memory.bw) ) - def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce: bool = True): from thunder.core.transform_common import dce self.trace_type = trace_type - # dce for the backward trace will be passed afterwards - self.trace: TraceCtx = dce(trace) if trace_type == TraceType.FW else trace + # dce for the backward trace will be passed afterwards as we might modify it before + self.trace: TraceCtx = dce(trace) if apply_dce else trace match self.trace_type: case TraceType.FW: logger.info(f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") # TODO (matteochen): support bw trace optimization even though with no fw traces cached (computational trace?) case TraceType.BW: - if not self.cached_fw_traces: - raise AssertionError("Can not optimize backward traces before forward traces") + if not self.compile_data.compile_options.get("autotune_restore_configuration", ""): + if not self.cached_fw_traces: + raise AssertionError("Can not optimize backward traces before forward traces") logger.info(f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") def optimize(self): @@ -1020,9 +1111,11 @@ def _optimize(): common_trace_blocks = repetead_trace_blocks( trace=self.trace, min_block_size=optimize_common_blocks_min_size if optimize_common_blocks else -1 ) - # print(common_trace_blocks) + # A valid block is defined with at least 2 trace regions if len(common_trace_blocks) >= 2 and optimize_common_blocks: - logger.info(f"Running optimization with common blocks reduction. Found block indices in trace: {common_trace_blocks}") + logger.info( + f"Running optimization with common blocks reduction. Found block indices in trace: {common_trace_blocks}" + ) reduced_trace = reduce_common_trace_blocks(trace=self.trace, common_blocks_in=common_trace_blocks) logger.info(f"Operating on reduced trace (by cutting common transformer blocks):\n{reduced_trace}") self.is_reduced = True @@ -1076,9 +1169,7 @@ def _optimize(): compile_data=self.compile_data, fusion_executor_compile_options_to_activate=placement_ctx.compile_options, ) - self.fusion_strat_helper.optimized_traces_time.append( - {ex.name: tuple([trc, placement_ctx.compile_options])} - ) + self.fusion_strat_helper.optimized_traces_time.append({ex.name: tuple([trc, placement_ctx])}) for placement_ctx, ex in zip( self.executor_placement_options.placement_options_mem, self.fusion_executors_saved_for_later ): @@ -1090,21 +1181,35 @@ def _optimize(): compile_data=self.compile_data, fusion_executor_compile_options_to_activate=placement_ctx.compile_options, ) - self.fusion_strat_helper.optimized_traces_mem.append( - {ex.name: tuple([trc, placement_ctx.compile_options])} - ) + self.fusion_strat_helper.optimized_traces_mem.append({ex.name: tuple([trc, placement_ctx])}) # Filter out the optimal candidates for the current serach iteration self._filter_candidates() + restore_file_name = self.compile_data.compile_options.get("autotune_restore_configuration", "") + match self.trace_type: case TraceType.FW: + # Perform optimization only if we don't restore it from a past configuration + if restore_file_name: + self.cached_computational_trace = self.trace + logger.info("Skipping forward trace optimization as it will be restored from a configuration file.") + return + # Clear any previous results self.cached_fw_traces = [] _optimize() # We have multiple cached optimized fw traces, this iteration will create a fw-bw pair for # every cached forward trace. At the end the best one will be picked up. case TraceType.BW: + # Perform optimization only if we don't restore it from a past configuration + if restore_file_name: + logger.info( + "Skipping backward trace optimization as it will be restored from a configuration file." + ) + self.cached_computational_backward_trace = self.trace + return + # Clear any previous results self.out_traces_candidates = [] @@ -1120,9 +1225,9 @@ def _optimize(): self.trace.bound_symbols = list(cached_self_trace.bound_symbols) # Set the current active cached forward trace context logger.info( - f"Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.compile_opt.fusion_tag if fw_trace_candidate.compile_opt is not None else 'None'}" + f"Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.ctx.compile_options.fusion_tag if fw_trace_candidate.ctx.compile_options is not None else 'None'}" ) - self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.compile_opt + self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.ctx logger.debug(f"Input bw trace:\n{self.trace}") @@ -1138,8 +1243,8 @@ def _optimize(): self.trace = dce(self.trace) # Enable any forward active compilation flag - if fw_trace_candidate.compile_opt: - wrap_fn_with_exeuctor_compile_option(fw_trace_candidate.compile_opt, _optimize) + if fw_trace_candidate.ctx.compile_options: + wrap_fn_with_exeuctor_compile_option(fw_trace_candidate.ctx.compile_options, _optimize) else: _optimize() @@ -1148,6 +1253,29 @@ def _optimize(): self.out_traces_candidates ) + # Save the tuning if requested + do_save = self.compile_data.compile_options.get("autotune_save_configuration", False) + if do_save: + model_name = self.compile_data.compile_options.get("model_name", "unknown") + file_name = f"{model_name}_runtime.json" + dump_traces_placement( + fw_trace=self.cached_computational_trace, + bw_trace=self.cached_computational_backward_trace, + file_name=file_name, + apply_remat=self.best_pair_runtime.apply_remat, + exs_fw=self.best_pair_runtime.executors_fw, + exs_bw=self.best_pair_runtime.executors_bw, + ) + file_name = f"{model_name}_memory.json" + dump_traces_placement( + fw_trace=self.cached_computational_trace, + bw_trace=self.cached_computational_backward_trace, + file_name=file_name, + apply_remat=self.best_pair_memory.apply_remat, + exs_fw=self.best_pair_memory.executors_fw, + exs_bw=self.best_pair_memory.executors_bw, + ) + class BackendOptimizer: """ @@ -1187,7 +1315,7 @@ def optimize(self): """ self.optimizer.optimize() - def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType): + def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce=True): """ Attach a new trace for executors optimization. diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index cf301929d9..59a72d5327 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -6,6 +6,7 @@ from thunder.core.prims import PrimIDs from thunder.core.proxies import ( AnyProxy, + CollectionProxy, FloatProxy, IntegerProxy, NumberProxy, @@ -16,7 +17,7 @@ ) from thunder.core.symbol import BoundSymbol, Symbol from thunder.core.trace import TraceCtx, from_trace, get_tracectx, reset_tracectx, set_tracectx -from thunder.extend import Executor, FusionExecutor, OperatorExecutor +from thunder.extend import Executor, FusionExecutor, OperatorExecutor, get_always_executors from thunder.core.utils import check, safe_map_flat import thunder.core.transforms as transforms from itertools import chain @@ -953,19 +954,31 @@ def reorder_executors_list(executors: Sequence, **kwargs): return reordered -def symbol_hash(bsym: BoundSymbol): +def symbol_hash( + *, + bsym: BoundSymbol, + ignore_returns_meta: bool = False, + ignore_unpacks_meta: bool = False, + ignore_unpacks: bool = False, +): """ Hash a bound symbol relying on its metadata (symbol name, bound symbol inputsa and outputs). No hash functions will be applied in order to leave the output readable. Args: bsym: A bound symbol. + ignore_returns_meta: If True, return statement metadata will be ignored + ignore_unpacks_meta: If True, unpack statements metadata will be ignored + ignore_unpacks: If True, unpack symbols will not be included. """ def _tensor_hash(t: TensorProxy) -> str: assert t.dtype shapes = [str(s) for s in t.shape] - return "{" + "-".join(shapes) + "/" + str(t.device) + t.dtype.full_name + str(t.requires_grad) + "}" + return "{" + "-".join(shapes) + "," + str(t.device) + "," + t.dtype.full_name + "," + str(t.requires_grad) + "}" + + def _collection_hash(c: CollectionProxy) -> str: + return "{Collection," + c.name + "," + str((type(c.collection()))) + "," + str(len(c.collection())) + "}" def _number_hash(t: NumberProxy) -> str: return "{" + str(t.value) + "}" @@ -980,7 +993,7 @@ def _sequence_hash(s: Sequence | None) -> str: ret = "[" for e in s: if e is None: - ret += "{None}" + ret += "{None}," elif isinstance(e, TensorProxy): ret += _tensor_hash(e) + "," elif isinstance(e, NumberProxy): @@ -988,30 +1001,52 @@ def _sequence_hash(s: Sequence | None) -> str: elif isinstance(e, Sequence): ret += _sequence_hash(e) + "," elif isinstance(e, AnyProxy): - ret += _any_proxy_hash(e) + ret += _any_proxy_hash(e) + "," + elif isinstance(e, CollectionProxy): + ret += _collection_hash(e) + "," elif isinstance(e, int) or isinstance(e, float) or isinstance(e, bool): - ret += f"{(type(e))}" + ret += "{" + f"{(type(e))}" + "}," elif isinstance(e, dtype): - ret += f"{(type(e))}" + ret += "{" + f"{(type(e))}" + "}," else: raise RuntimeError(f"Not implemented {type(e)}. Failed bsym: {bsym}") return ret + "]" def _hash(bsym: BoundSymbol) -> str: + match = { + TensorProxy: _tensor_hash, + tuple: _sequence_hash, + list: _sequence_hash, + Sequence: _sequence_hash, + CollectionProxy: _collection_hash, + } + + if ignore_returns_meta and bsym.sym.id == PrimIDs.RETURN: + return "{return}" + + if ignore_unpacks and bsym.sym.name.startswith("unpack"): + return "" + elif ignore_unpacks_meta and bsym.sym.name.startswith("unpack"): + if bsym is not None and bsym.output is not None: + if isinstance(bsym.output, Sequence) and len(bsym.output) < 1: + return "" + return "{general_unpack}" + h = bsym.sym.name # Handle tensor as output or sequences - if not isinstance(bsym.output, TensorProxy) and not isinstance(bsym.output, Sequence): + if type(bsym.output) not in match.keys(): raise RuntimeError(f"type {type(bsym.output)} not implemented") h += ( "#out:" - + (_tensor_hash(bsym.output) if isinstance(bsym.output, TensorProxy) else _sequence_hash(bsym.output)) + + match[type(bsym.output)](bsym.output) + "#in:" # Args is always a tuple + _sequence_hash(bsym.args) ) return h - return _hash(bsym) + h = _hash(bsym) + return ("{" + h + "}") if h else h # Both lhs and rhs are included in the range @@ -1073,7 +1108,7 @@ def _lcs(start_indexes) -> int: lcs = 0 while start_indexes[0] < max_first_len and start_indexes[-1] < max_last_len: # Get all the hashes - hashes = [symbol_hash(symbols[i]) for i in start_indexes] + hashes = [symbol_hash(bsym=symbols[i]) for i in start_indexes] # Advance if all the hashes coincides uniques = set(hashes) if len(uniques) == 1: @@ -1092,7 +1127,7 @@ def _skip(bsym: BoundSymbol) -> bool: break if _skip(bsym): continue - h = symbol_hash(bsym) + h = symbol_hash(bsym=bsym) if h in bsym_indexes: bsym_indexes[h].append(i) else: @@ -1114,7 +1149,7 @@ def _range_seen(index: int, s: set): if _skip(bsym): continue - h = symbol_hash(bsym) + h = symbol_hash(bsym=bsym) # Normally, bsym are expected to output a TensorProxy if not isinstance(bsym.output, Proxy) or h in seen_hashes or _range_seen(i, seen_ranges): continue @@ -1146,6 +1181,7 @@ def _range_seen(index: int, s: set): return [] from thunder.backend_optimizer.optimizer import logger + logger.debug(f"Max block lcs fouund: {max_lcs}") logger.debug(f"{[(symbols[r[0]].output.name, symbols[r[1]].output.name) for r in res]}") @@ -1459,3 +1495,155 @@ def get_fw_bw_split_backends_options(bsym: BoundSymbol | None = None, **kwargs) } return options.get(bsym.sym.name, []) if bsym else options + + +def trace_symbolic_hash(trace: TraceCtx) -> str: + res = "" + for b in trace.bound_symbols: + # Ignoring unpacks as when tuple has size zero, there are cases when None is given as static args/output and cases where a zero sized tuple is returned. + res += symbol_hash(bsym=b, ignore_returns_meta=True, ignore_unpacks_meta=True) + return res + + +supported_file_modes = set(["json"]) + + +def dump_traces_placement( + *, + fw_trace: TraceCtx, + bw_trace: TraceCtx, + exs_fw: list[Executor], + exs_bw: list[Executor], + apply_remat: bool, + file_name: str, + output_mode: str = "json", +) -> str: + """ + Creates an output configuration file where the current forward and backward trace optimization are saved. + + Args: + fw_trace: A forward trace. + bw_trace: A backward trace. + exs_fw: Forward trace region executors. + exs_bw: Backward trace region executors. + apply_remat: If forward and backward traces are output of rematerialize_forward_and_backward + file_name: The output file name. + output_mode: The output file format. Must be one of ['json']. + """ + assert output_mode in supported_file_modes + + if output_mode == "json": + # We defined an unique trace by reading its bsym metadata, the proxies name are ignored as they may + # change but the overall computation can remain the same. + fw_hash = trace_symbolic_hash(fw_trace) + bw_hash = trace_symbolic_hash(bw_trace) + + executors_fw_name = [ex.name if (ex and ex.name != "empty") else "None" for ex in exs_fw] + executors_bw_name = [ex.name if (ex and ex.name != "empty") else "None" for ex in exs_bw] + + assert len(fw_trace.bound_symbols) == len(executors_fw_name) + assert len(bw_trace.bound_symbols) == len(executors_bw_name) + + from thunder.backend_optimizer.optimizer import logger + + logger.info( + f"Size match between len(fw_trace.bound_symbols)[{len(fw_trace.bound_symbols)}] and len(executors_fw_name)[{len(executors_fw_name)}]" + ) + logger.info( + f"Size match between len(bw_trace.bound_symbols)[{len(bw_trace.bound_symbols)}] and len(executors_bw_name)[{len(executors_bw_name)}]" + ) + logger.info(f"Saving configuration in {file_name}") + + data = { + "forward": { + "hash": fw_hash, + "executors": executors_fw_name, + }, + "backward": { + "hash": bw_hash, + "executors": executors_bw_name, + }, + "rematerialize": apply_remat, + } + try: + with open(file_name, "w") as file: + import json + + json.dump(data, file) + except Exception: + from thunder.backend_optimizer.optimizer import logger + import traceback + + err = traceback.format_exc() + logger.error(f"Can not dump {file_name} file:\n{err}") + return "" + return file_name + return "" + + +def apply_results_from_file( + *, fw_trace: TraceCtx, bw_trace: TraceCtx, file: str, input_mode: str = "json" +) -> tuple[TraceCtx, TraceCtx]: + """ + Generate a transformed forward and backward trace from a configuration file. + Compatibility check is performed on both traces. + + Args: + fw_trace: The original augmented forward trace. + bw_trace: The original backward trace. + file: The configuration file. + input_mode: The configuration structure. Must be one of ['json']. + """ + import json + from thunder.executors.torchex import ex as torch_ex + from thunder.executors.pythonex import ex as python_ex + from thunder.executors.sdpaex import sdpa_ex + from thunder.executors.cudnnex import cudnn_ex + from thunder.executors.fa3ex import fa3_ex + from thunder.executors.nvfuserex_impl import ex as nvfuser_ex + from thunder.executors.torch_compile import torch_compile_ex + from thunder.executors.torch_autograd import update_bw_from_forward_optimization + + assert input_mode in supported_file_modes + + # Extend this if more executors will be added + conversion_map: dict[str | Hashable, Executor] = { + "None": Executor("empty"), + torch_ex.name: torch_ex, + python_ex.name: python_ex, + nvfuser_ex.name: nvfuser_ex, + torch_compile_ex.name: torch_compile_ex, + sdpa_ex.name: sdpa_ex, + cudnn_ex.name: cudnn_ex, + fa3_ex.name: fa3_ex, + } + + if input_mode == "json": + data = json.load(open(file, "r")) + + fw_hash = trace_symbolic_hash(fw_trace) + bw_hash = trace_symbolic_hash(bw_trace) + assert fw_hash == data["forward"]["hash"] + assert bw_hash == data["backward"]["hash"] + + fw_executors_recovered: list[str] = data["forward"]["executors"] + extrace_fw = assign_executors( + in_trace=fw_trace, + executors_list=[conversion_map[ex] for ex in fw_executors_recovered], + empty_str="empty", + always_executors=get_always_executors(), + ) + bw_executors_recovered: list[str] = data["backward"]["executors"] + bw_trace = update_bw_from_forward_optimization(fw=extrace_fw, bw=bw_trace) + extrace_bw = assign_executors( + in_trace=bw_trace, + executors_list=[conversion_map[ex] for ex in bw_executors_recovered], + empty_str="empty", + always_executors=get_always_executors(), + ) + + if data["rematerialize"]: + from thunder.core.rematerialization import rematerialize_forward_and_backward + + return rematerialize_forward_and_backward(extrace_fw, extrace_bw) + return extrace_fw, extrace_bw diff --git a/thunder/core/vjp_utils.py b/thunder/core/vjp_utils.py index a22b23214e..1131a69bf9 100644 --- a/thunder/core/vjp_utils.py +++ b/thunder/core/vjp_utils.py @@ -255,7 +255,7 @@ def bw_fn(*args, **kwargs): # as they are expected to work on same input size, shape and dtype). optmimizer_common_transformer_block = cd.compile_options.get('autotune_optimize_common_blocks', False) # The generated hash will rely on the operation, input args metadata and output metadata - h = symbol_hash(bsym) + h = symbol_hash(bsym=bsym) # Recover the cache stored in the compile data autotuner_bsym_with_gradfn_executor_cache = cd.autotuner_bsym_with_gradfn_executor_cache diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index a6fb17526a..5239aa1f41 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -157,7 +157,7 @@ def autotune_transform_for_execution( trace = apply_bucketing_to_grad_allreduce(trace) # Attach new trace and set the debug file name - optimizer_context.attach_trace(trace=trace, trace_type=trace_type) + optimizer_context.attach_trace(trace=trace, trace_type=trace_type, apply_dce=trace_type == TraceType.FW) optimizer_context.log_file_name = f"autotune_transform_for_execution_{sig_name}.log" # Forward traces are cached inside the context optimizer_context.optimize() @@ -176,11 +176,13 @@ def autotune_transform_for_execution( # Assign the trace provenance match trace_type: case TraceType.FW: - fw_traces = optimizer_context.get_optimal_fw_traces() - for trc in fw_traces: - trc.set_provenance( - TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") - ) + cd = get_compile_data() + if not cd or not cd.compile_options.get('autotune_restore_configuration', ""): + fw_traces = optimizer_context.get_optimal_fw_traces() + for trc in fw_traces: + trc.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) return None case TraceType.BW: bw_extrace.set_provenance( diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 07d4c6c98e..336db20986 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -499,7 +499,16 @@ def forward(self, x): @pytest.mark.parametrize( "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_cudagraphs, use_te", [ - (Model_1(32, 32), (32, 32), torch.float32, "runtime", [nvfuserex], [[nvfuserex, torchex, pythonex]], True, False), + ( + Model_1(32, 32), + (32, 32), + torch.float32, + "runtime", + [nvfuserex], + [[nvfuserex, torchex, pythonex]], + True, + False, + ), ( Model_1(32, 32), (32, 32), @@ -508,7 +517,7 @@ def forward(self, x): [torch_compile_ex], [[torch_compile_ex, torchex, pythonex]], True, - False + False, ), ( Model_1(4096, 4096), @@ -518,7 +527,7 @@ def forward(self, x): [transformer_engine_ex], [[transformer_engine_ex, nvfuserex, torchex, pythonex]], False, - True + True, ), ( Model_2(), @@ -541,7 +550,7 @@ def forward(self, x): [transformer_engine_ex, sdpa_ex, nvfuserex, torchex, pythonex], ], False, - True + True, ), ], ) @@ -554,14 +563,18 @@ def test_autotuner( executors: list, expected_executors: list[list], use_cudagraphs: bool, - use_te: bool + use_te: bool, ): def _run(): model.to("cuda") x = torch.randn(tensor_shape, dtype=dtype, device="cuda") jitted_def = thunder.jit(model, executors=executors) jitted_auto = thunder.jit( - model, autotune_type=autotune_type, executors=executors, use_cudagraphs=use_cudagraphs, autotune_enable_te=use_te + model, + autotune_type=autotune_type, + executors=executors, + use_cudagraphs=use_cudagraphs, + autotune_enable_te=use_te, ) y_def = jitted_def(x) y_auto = jitted_auto(x) @@ -639,3 +652,43 @@ def test_reduce_common_trace_blocks(): for b in reduced_trace.bound_symbols: if hasattr(b.output, "name"): assert b.output.name not in should_remove + +@requiresCUDA +def test_save_configuration_cuda(): + class _LLaMAMLP(torch.nn.Module): + def __init__(self, n_embd, intermediate_size) -> None: + super().__init__() + self.fc_1 = torch.nn.Linear(n_embd, intermediate_size, bias=False) + self.fc_2 = torch.nn.Linear(n_embd, intermediate_size, bias=False) + self.proj = torch.nn.Linear(intermediate_size, n_embd, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + return self.proj(x) + + with torch.device("cuda"): + model = _LLaMAMLP(4, 4) + jitted = thunder.jit( + model, + autotune_type="memory", + model_name="llamamlp", + autotune_save_configuration=True, + ) + jitted_recovered = thunder.jit( + model, + autotune_type="runtime", + autotune_restore_configuration="llamamlp_memory.json", + ) + + x = torch.randn(4, 4) + a = jitted(x) + b = jitted_recovered(x) + + torch.testing.assert_close(a, b) + + for bsym_a, bsym_b in zip( + thunder.last_traces(jitted)[-1].bound_symbols, thunder.last_traces(jitted_recovered)[-1].bound_symbols + ): + assert bsym_a.sym.executor == bsym_b.sym.executor From 625b473c51ce1af68d5cc79beeb35cc75f856d9e Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 6 Sep 2024 00:29:25 +0300 Subject: [PATCH 138/171] Doc --- thunder/backend_optimizer/optimizer.py | 126 ++++++++++++++----------- 1 file changed, 72 insertions(+), 54 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index d6b496616e..58c2b772c0 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -35,16 +35,16 @@ class FusionCompileOptionsHelper: Represents compile options for a fusion executor. Attributes: - fusion_tag: A label representing the fusion ops regarding a compile option (e.g. nv_linear). - symbol_tag: The symbol name - id: The symbol id. - impl: A callable implementation. - checker: A callable checker. + fusion_tag (str): A label representing the fusion ops regarding a compile option (e.g. nv_linear). + symbol_tag (str): The symbol name + id (PrimIDs): The symbol id. + impl (Callable): A callable implementation. + checker (Callable): A callable checker. """ def __init__(self, fusion_tag: str, symbol_tag: str, id: PrimIDs, impl: Callable, checker: Callable) -> None: - self.fusion_tag = fusion_tag - self.symbol_tag = symbol_tag + self.fusion_tag: str = fusion_tag + self.symbol_tag: str = symbol_tag self.id: PrimIDs = id self.impl: Callable = impl self.checker: Callable = checker @@ -55,8 +55,8 @@ class FusionExecutorsPlacementCtx: Represents a executor placement context. Attributes: - placement: A list of executors. - compile_options: Any compile options being used for the fusion executor contained in the placement. + placement (list): A list of executors. + compile_options (FusionExecutorsPlacementCtx | None): Any compile options being used for the fusion executor contained in the placement. """ def __init__(self, *, placement: list, compile_options: FusionCompileOptionsHelper | None = None) -> None: @@ -69,9 +69,9 @@ class TraceCandidate: Represents an optimal trace candidate. Attributes: - trace: The candidate trace. - ctx: Trace's placement context. - label: A generic label to identify this candidate. + trace (TraceCtx): The candidate trace. + ctx (FusionExecutorsPlacementCtx): Trace's placement context. + label (str): A generic label to identify this candidate. """ def __init__( @@ -91,10 +91,10 @@ class TraceCandidates: Represents an optimal pair of trace candidates (compute time and memory consumption). Attributes: - best_time: The trace with the optimal runtime. - best_mem: The trace with the optimal peak memory consumption. - placement_ctx_time: Trace placement context: exeuctors and any applied fusion compile options. - placement_ctx_mem: Trace placement context: exeuctors and any applied fusion compile options. + best_time (TraceCtx): The trace with the optimal runtime. + best_mem (TraceCtx): The trace with the optimal peak memory consumption. + placement_ctx_time (FusionExecutorsPlacementCtx): Trace placement context with exeuctors and any applied fusion compile options. + placement_ctx_mem (FusionExecutorsPlacementCtx): Trace placement context with exeuctors and any applied fusion compile options. """ def __init__( @@ -112,12 +112,18 @@ def __init__( def __repr__(self) -> str: """ Give a representation for the current object. + + Returns: + str: A string as the representation of the current object """ return f"\nBest runtime candidate:\n{self.best_time}\nBest memory candidate:\n{self.best_mem}" def is_set(self) -> bool: """ Check that the optimal trace pair has been set. + + Returns: + bool: A flag indicating if the optimal trace is not None. """ return False if self.best_time is None or self.best_mem is None else True @@ -126,8 +132,8 @@ def attach_best_time_candidate(self, trace: TraceCtx, ctx: FusionExecutorsPlacem Attach a new best time trace result. Args: - trace: The trace to assign. - ctx: The trace placement context. + trace (TraceCtx): The trace to assign. + ctx (FusionExecutorsPlacementCtx | None): The trace placement context. """ self.best_time = trace self.placement_ctx_time = ctx @@ -137,8 +143,8 @@ def attach_best_mem_candidate(self, trace: TraceCtx, ctx: FusionExecutorsPlaceme Attach a new best memory trace result. Args: - trace: The trace to assign. - ctx: The trace placement context. + trace (TraceCtx): The trace to assign. + ctx (FusionExecutorsPlacementCtx | None): The trace placement context. """ self.best_mem = trace self.placement_ctx_mem = ctx @@ -146,18 +152,27 @@ def attach_best_mem_candidate(self, trace: TraceCtx, ctx: FusionExecutorsPlaceme def iterable(self) -> tuple[tuple, tuple]: """ Returns an iterable object over the traces paired with their contexts. + + Returns: + tuple: A tuple with paired values of performance metric and its context. """ return (self.best_time, self.placement_ctx_time), (self.best_mem, self.placement_ctx_mem) def trace_ctx_iterable(self) -> tuple[TraceCtx | None, TraceCtx | None]: """ Returns an iterable object over the traces. + + Returns: + tuple: A tuple of traces with time and memory consumption targets. """ return self.best_time, self.best_mem def placement_ctx_iterable(self) -> tuple[FusionExecutorsPlacementCtx | None, FusionExecutorsPlacementCtx | None]: """ Returns an iterable object over the placement contexts. + + Returns: + tuple: A tuple of contexes referring to traces targetting compute time and peak memory consumption. """ return self.placement_ctx_time, self.placement_ctx_mem @@ -167,13 +182,13 @@ class OutputCandidate: Represents a final output candidate: forward and backward trace pair. Attributes: - fw: The forward trace. - bw: The backward trace. - executors_fw: The forward trace regions' executors - executors_bw: The backward trace regions' executors - compile_opt: Any compile options being used for a fusion executor in the forward trace. - tot_cost: The total cost to execute the pair (ms for a time strategy and GB for a memory strategy). - apply_remat: If rematerialization has been applied. + fw (TraceCtx): The forward trace. + bw (TraceCtx): The backward trace. + executors_fw (list): The forward trace regions' executors + executors_bw (list): The backward trace regions' executors + compile_opt (FusionExecutorsPlacementCtx | None): Any compile options being used for a fusion executor in the forward trace. + tot_cost (float): The total cost to execute the pair (ms for a time strategy and GB for a memory strategy). + apply_remat (bool): If rematerialization has been applied. """ def __init__( @@ -193,11 +208,14 @@ def __init__( self.executors_bw: list[Executor] = executors_bw self.compile_opt: FusionCompileOptionsHelper | None = compile_opt self.tot_cost: float = cost - self.apply_remat = apply_remat + self.apply_remat: bool = apply_remat def __repr__(self) -> str: """ Give a representation of the current object. + + Returns: + str: A string representing the current object. """ return f"Final output candidate: forward trace:\n{self.fw.__repr__()}\nFinal output candidate: backward trace:\n{self.bw.__repr__()}" @@ -207,11 +225,11 @@ class FusionStratHelper: Represents a helper structure for the fusion strategy. Attributes: - supported_executors: A list of supported fusion executors. - optimized_traces_mem: a list of dictionaries containing informations regarding the optimized traces for peak memory consumption. - optimized_traces_mem_benchmark_only: a list of dictionaries containing informations regarding the optimized traces for peak memory consumption (used only for internal benchmarking). - optimized_traces_time: a list of dictionaries containing informations regarding the optimized traces for total compute time. - optimized_traces_time_benchmark_only: a list of dictionaries containing informations regarding the optimized traces for total compute time (used only for internal benchmarking). + supported_executors (set): A list of supported fusion executors. + optimized_traces_mem (list): a list of dictionaries containing informations regarding the optimized traces for peak memory consumption. + optimized_traces_mem_benchmark_only (list): a list of dictionaries containing informations regarding the optimized traces for peak memory consumption (used only for internal benchmarking). + optimized_traces_time (list): a list of dictionaries containing informations regarding the optimized traces for total compute time. + optimized_traces_time_benchmark_only (list): a list of dictionaries containing informations regarding the optimized traces for total compute time (used only for internal benchmarking). """ def __init__(self) -> None: @@ -227,8 +245,8 @@ class ExecutorPlacementOptions: Represents an aggregate placement options for executors combining those that targets peak memory consumption and those for total compute time. Attributes: - placement_options_mem: A list of placement contexts. - placement_options_time: A list of placement contexts. + placement_options_mem (list): A list of placement contexts. + placement_options_time (list): A list of placement contexts. """ def __init__(self) -> None: @@ -241,25 +259,25 @@ class PlacerBase: Represents a base (interface) class for a placement class. Attributes: - always_executors: A list of always present executors. - empty_executor_hashable_placeholder: A label representing en empty executor. - executors: A list of executors to use. - fusion_executors: A list of fusion executors to use. - fusion_executors_saved_for_later: A helper list containing maybe repeated fusion executors. - debug_msg: A dynamic filled log message. - log_file_name: The output log file name if generated. - produce_log: A tuning parameter to control log file generation. - optimizer_type: The optimization target. - active_fw_trace_ctx: An active forward trace set to optimize backward. - cached_fw_traces: Cached optimized forward traces. - cached_computational_trace: Original computational trace - cached_computational_backward_trace: Original computational backward trace - bw_trace_candidates: An instance of trace candidates. - best_pair_runtime: A final traace pair targetting the compute time. - best_pair_memory: A final traace pair targetting the peak memory consumption. - apply_bucketing_bw_trace: A distributed flag. - benchmark_iters: Benchmark iteration steps. - compile_data: Thunder compilation data. + always_executors (tuple): A list of always present executors. + empty_executor_hashable_placeholder (str): A label representing en empty executor. + executors (Sequence): A list of executors to use. + fusion_executors (Sequence): A list of fusion executors to use. + fusion_executors_saved_for_later (Sequence): A helper list containing maybe repeated fusion executors. + debug_msg (str): A dynamic filled log message. + log_file_name (str): The output log file name if generated. + produce_log (bool): A tuning parameter to control log file generation. + optimizer_type (OptimizerType): The optimization target. + active_fw_trace_ctx (tuple): An active forward trace set to optimize backward. + cached_fw_traces (list): Cached optimized forward traces. + cached_computational_trace (TraceCtx): Original computational trace + cached_computational_backward_trace (TraceCtx): Original computational backward trace + bw_trace_candidates (TraceCandidate): An instance of trace candidates. + best_pair_runtime (OutputCandidate): A final trace pair targetting the compute time. + best_pair_memory (OutputCandidate): A final trace pair targetting the peak memory consumption. + apply_bucketing_bw_trace (bool): A distributed flag. + benchmark_iters (int): Benchmark iteration steps. + compile_data (Any): Thunder compilation data. """ def __init__( From ed12755697500223b567827241273e73bdb7df77 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 14:42:33 +0300 Subject: [PATCH 139/171] Removed trace print --- thunder/backend_optimizer/optimizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 58c2b772c0..cee8f5070b 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -1096,13 +1096,13 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce: boo match self.trace_type: case TraceType.FW: - logger.info(f"New forward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") + logger.info(f"New forward trace to optimize (strat = {self.optimizer_type})") # TODO (matteochen): support bw trace optimization even though with no fw traces cached (computational trace?) case TraceType.BW: if not self.compile_data.compile_options.get("autotune_restore_configuration", ""): if not self.cached_fw_traces: raise AssertionError("Can not optimize backward traces before forward traces") - logger.info(f"New backward trace to optimize (strat = {self.optimizer_type}):\n{self.trace}") + logger.info(f"New backward trace to optimize (strat = {self.optimizer_type})") def optimize(self): from thunder.core.transform_common import dce @@ -1135,7 +1135,7 @@ def _optimize(): f"Running optimization with common blocks reduction. Found block indices in trace: {common_trace_blocks}" ) reduced_trace = reduce_common_trace_blocks(trace=self.trace, common_blocks_in=common_trace_blocks) - logger.info(f"Operating on reduced trace (by cutting common transformer blocks):\n{reduced_trace}") + logger.info("Operating on reduced trace (by cutting common transformer blocks)") self.is_reduced = True self.cached_original_trace = self.trace self.trace = reduced_trace @@ -1242,9 +1242,9 @@ def _optimize(): self.trace = from_trace(cached_self_trace) self.trace.bound_symbols = list(cached_self_trace.bound_symbols) # Set the current active cached forward trace context - logger.info( - f"Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.ctx.compile_options.fusion_tag if fw_trace_candidate.ctx.compile_options is not None else 'None'}" - ) + # logger.info( + # f"Current fw cached ctx:\n{fw_trace_candidate.trace}\nOptions: {fw_trace_candidate.ctx.compile_options.fusion_tag if fw_trace_candidate.ctx.compile_options is not None else 'None'}" + # ) self.active_fw_trace_ctx = fw_trace_candidate.trace, fw_trace_candidate.ctx logger.debug(f"Input bw trace:\n{self.trace}") From 519855add22facda4c6a66b9e1ffd08779510b2a Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 14:58:08 +0300 Subject: [PATCH 140/171] Enhanced timing measurement / autotuned nvmath matmul --- examples/dev/litGPT.py | 72 ++++++------------------ thunder/benchmarks/utils.py | 100 +++++++++++++++++++++++++++++----- thunder/executors/nvmathex.py | 51 +++++++++++------ thunder/extend/__init__.py | 1 + 4 files changed, 138 insertions(+), 86 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 56f9bb74b3..b32c128bc0 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -1,6 +1,5 @@ """ -This script benchmarks litGPT models in a easier wrt to benchmark_litgpt.py way with a fake training loop with no optimizers in order to focus more on -forward and backward computation time and not others kernel during the loop. +This script benchmarks litGPT models in a easier way wrt to benchmark_litgpt.py with a fake training loop with no optimizers. """ from litgpt import GPT @@ -9,16 +8,16 @@ torch_fw_bw_benchmark, torch_fw_bw_benchmark_nvsight, torch_total_benchmark, + torch_timer_total_benchmark ) from thunder.tests.litgpt_model import Config import thunder import torch -import time +# import time torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - class Test: def __init__( self, @@ -41,7 +40,7 @@ def __init__( self.optimize_transformer_min_block_size = optimize_transformer_min_block_size -layers = [ +to_run = [ Test( 1, "runtime", @@ -49,26 +48,15 @@ def __init__( executors=[ "cudnn", "sdpa", - # "fa3", - "nvfuser", - "torchcompile", - ], - ), - Test( - 4, - "runtime", - 1, - executors=[ - "cudnn", - "sdpa", - # "fa3", + "fa3", "nvfuser", + "nvmath", "torchcompile", ], ), ] -for test in layers: +for test in to_run: try: cfg = Config.from_name(test.model_name) cfg.n_layer = test.layers @@ -79,6 +67,7 @@ def __init__( with torch.device("cuda"): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size)) + target = torch.ones_like(x) print(f"Input size: {x.size()}") eager = model @@ -92,47 +81,20 @@ def __init__( autotune_optimize_common_blocks=test.optimize_transformer_blocks, autotune_optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, ) - print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) - s = time.time_ns() - print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) - e = time.time_ns() - print("Compilation time:", {(e - s) / 1000000000}, "s") + # print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + # s = time.time_ns() + # print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) + # e = time.time_ns() + # print("Compilation time:", {(e - s) / 1000000000}, "s") iters = 100 - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - print("\n\n####################################################", test.model_name) - print(f"Results thunder benchmark ({iters} iters):") - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters, nvsight=False) - # thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 10, nvsight=True) - - print(f"\n\nResults torch fw bw benchmark ({iters} iters):") callables = [eager, torch_compile, jmodel_def, jmodel_auto] labels = ["eager", "torch.compile", "Thunder", "Thunder Autotuner"] inputs = [x, x, x, x] - torch_fw_bw_benchmark(callables, labels, inputs, iters) - print(f"\n\nResults torch total benchmark ({iters} iters):") - torch_total_benchmark(callables, labels, inputs, iters) - - torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) - - print("\n\n\n\n\n\n") - print(f"{thunder.last_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_traces(jmodel_auto)[-1]}") - - print("\n\n") - print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") + print(f"\nResults torch total benchmark ({iters} iters):") + torch_total_benchmark(callables, labels, inputs, iters, torch.nn.functional.cross_entropy) + print(f"\nResults torch timer benchmark ({iters} iters):") + torch_timer_total_benchmark(callables, labels, inputs, test.model_name, torch.nn.functional.cross_entropy) except Exception as e: print(f"Test failed:\n{e}") import traceback diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index c6977eca5d..c19f25cffb 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -1,6 +1,7 @@ from collections.abc import Callable import torch from thunder.backend_optimizer.utils import benchmark_trace +from torch.utils.benchmark import Timer, Compare warm_up_iters = 50 @@ -25,8 +26,18 @@ def __init__( self.bw_fn: Callable | None = bw_fn self.executor = executor +def _run_loss(model, input, target, loss_fn): + logits = model(input) + logits = logits.reshape(-1, logits.size(-1)) + target = target.reshape(-1) + loss = loss_fn(logits, target) + loss.backward() -def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int) -> None: +def _run_autograd(model, input): + y = model(input) + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + +def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int, loss) -> None: """ Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). This util will generate nvsight system profiles. @@ -61,10 +72,11 @@ def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iter torch.cuda.cudart().cudaProfilerStop() -def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: +def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None) -> None: """ - Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). - Forward and backward pass will be both recorded. + Benchmark a mock trainig loop of the given models. Time measurements will be performed by using cuda events. + A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. + Forward and backward pass will be recorded separately. Args: models: a list of Callable models to benchmark. @@ -74,9 +86,12 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) """ for m, input, label in zip(models, inputs, labels): # Warm up + target = torch.ones_like(input) for _ in range(warm_up_iters): - y = m(input) - y.sum().backward() + if loss_fn is not None: + _run_loss(m, input, target, loss_fn) + else: + _run_autograd(m, input) start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -106,14 +121,20 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) max_allocated_bytes = 0 torch.cuda.synchronize() for i in range(iters): + target = torch.ones_like(input) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) y = m(input) - loss = y.sum() torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) start_events[i].record(stream) - loss.backward() + if loss_fn is not None: + y = y.reshape(-1, y.size(-1)) + target = target.reshape(-1) + loss = loss_fn(y, target) + loss.backward() + else: + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) end_events[i].record(stream) max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) @@ -124,10 +145,51 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int) print(f"{label} tot bw time: {tot_time} ms") print(f"{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB") +def torch_timer_total_benchmark( + models: list, labels: list, inputs: list, name: str = "Model", loss_fn: Callable | None = None +) -> None: + """ + Benchmark a mock trainig loop time of the given models. Measurements will be computed by using torch.utils.benchmark.Timer. + A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. -def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) -> None: + Args: + models: a list of Callable models to benchmark. + labels: a list of labels (names) referring to the models. + inputs: a list of inputs to give to models' forward pass. + name: the model name + loss_fn: a Pytorch loss function """ - Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). + results = [] + for m, l, i in zip(models, labels, inputs): + t = Timer( + stmt=""" + _run_loss(m, i, target, loss_fn) + """ + if loss_fn is not None + else """ + _run_atograd(m, i) + """, + globals={ + "i": i, + "m": m, + "target": torch.zeros_like(i), + "_run_loss": _run_loss, + "_run_autograd": _run_autograd, + "loss_fn": loss_fn, + }, + label=name, + description=l, + ) + results.append(t.blocked_autorange(min_run_time=1)) + print(results[-1]) + compare = Compare(results) + compare.colorize(rowwise=True) + compare.print() + +def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None) -> None: + """ + Benchmark a mock trainig loop of the given models. Time measurements will be performed by using cuda events. + A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. The complete time will be recorded with no split between forward pass and backward pass. Args: @@ -138,9 +200,12 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) """ for m, input, label in zip(models, inputs, labels): # Warm up + target = torch.ones_like(input) for _ in range(warm_up_iters): - y = m(input) - y.sum().backward() + if loss_fn is not None: + _run_loss(m, input, target, loss_fn) + else: + _run_autograd(m, input) start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -148,14 +213,20 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) max_allocated_bytes = 0 torch.cuda.synchronize() for i in range(iters): + target = torch.ones_like(input) torch.cuda.empty_cache() torch.cuda._sleep(1_000_000) torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) start_events[i].record(stream) y = m(input) - loss = y.sum() - loss.backward() + if loss_fn is not None: + y = y.reshape(-1, y.size(-1)) + target = target.reshape(-1) + loss = loss_fn(y, target) + loss.backward() + else: + torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) end_events[i].record(stream) max_allocated_bytes = max(max_allocated_bytes, torch.cuda.max_memory_allocated(torch.cuda.current_device())) @@ -166,7 +237,6 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int) print(f"{label} tot time: {tot_time} ms") print(f"{label} max allocated memory: {max_allocated_bytes / (2**30)} GB") - def thunder_fw_bw_benchmark( fw_traces: list, bw_traces: list, fw_labels: list, bw_labels: list, iters: int, nvsight: bool = False ) -> None: diff --git a/thunder/executors/nvmathex.py b/thunder/executors/nvmathex.py index 0661822a78..b0ae70dbb7 100644 --- a/thunder/executors/nvmathex.py +++ b/thunder/executors/nvmathex.py @@ -1,31 +1,50 @@ -from thunder import TensorProxy +from importlib.metadata import version from thunder.core.prims import PrimIDs +import logging import nvmath import thunder import thunder.torch as ltorch import torch -nvmath_ex = thunder.extend.OperatorExecutor("nvmath", version="0.1.0") +logger = logging.getLogger("Thunder nvmath_ex") +logger.disabled = True + +nvmath_ex = thunder.extend.OperatorExecutor("nvmath", version=version('nvmath-python')) thunder.extend.register_executor(nvmath_ex) +_cache = {} +options = nvmath.linalg.advanced.MatmulOptions(logger=logger) + +def _cache_key(a: torch.Tensor, b: torch.Tensor) -> str: + def _get_shape_str(t: tuple): + return '_'.join(str(num) for num in t) + + return f'{_get_shape_str(a.size())}-{_get_shape_str(b.size())}' def _nvmath_linalg_advanced_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: - return nvmath.linalg.advanced.matmul(a, b) - - -def _nvmath_linalg_advanced_matmul_checker(a: TensorProxy, b: TensorProxy) -> bool: - if len(a.shape) < 2 or len(b.shape) < 2: - return False - if a.shape[-1] != b.shape[-2]: - return False - if a.device != b.device: - return False - if a.dtype != b.dtype: - return False - # Handle distribuited + # Check if these shapes have been cached + k = _cache_key(a, b) + if k in _cache: + algo = _cache[k] + with nvmath.linalg.advanced.Matmul(a, b, options=options) as mm: + # Provide the optimized algorithms directly to plan. + mm.plan(algorithms=algo) + # Execute the multiplication + return mm.execute() + + # Compute a new shape and cache the result + with nvmath.linalg.advanced.Matmul(a, b, options=options) as mm: + preferences = nvmath.linalg.advanced.MatmulPlanPreferences(limit=25) + mm.plan(preferences=preferences) + mm.autotune(iterations=10) + # Execute the multiplication + result = mm.execute() + _cache[k] = mm.algorithms + return result + +def _nvmath_linalg_advanced_matmul_checker(*args, **kwargs) -> bool: return True - nvmath_linalg_advanced_matmul = nvmath_ex.register_operator( "nvmath_linalg_advanced_matmul", like=ltorch.matmul, diff --git a/thunder/extend/__init__.py b/thunder/extend/__init__.py index 337c9b0ada..5a51f968c8 100644 --- a/thunder/extend/__init__.py +++ b/thunder/extend/__init__.py @@ -365,6 +365,7 @@ def get_all_executors() -> tuple[Executor, ...]: torchex, transformer_engineex, triton_crossentropy, + nvmathex ) return tuple(_executor_map.values()) From f09a8132480b0f3f0ec7e13920a85759b9972f14 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 15:06:12 +0300 Subject: [PATCH 141/171] Jit doc --- thunder/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/thunder/__init__.py b/thunder/__init__.py index cce32c1524..468d855ce0 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -305,6 +305,13 @@ def jit( interpretation: (deprecated: don't use this, use the thunder.functional.jit entry point to get the functional jit) transforms: List of transforms to be applied. It should be an instance :class:`thunder.core.transforms.Transform`. Default: ``None`` + + autotune_type: string representing the required autotuner performance target (``"runtime"`` or ``"memory"``). + autotune_nv_enable_options: boolean to enable nvFuser compilation options autotuning. Currently at most one option will be used. Default: ``"False"`` + autotune_enable_te: boolean to enable TransformerEngineFP8 executor autotuning. Default: ``"False"`` + autotune_optimize_common_blocks: boolean to enable trace's common block optimization during the compilation (for example transformer layers). This optimization can be used if you are working with a model with repeated block structures as transformer based models. You don't need to know + where a block starts or ends as it's handled automatically. Default: ``"False"`` + autotune_optimize_common_blocks_min_size: integer to control the minimum block length to trigger the common block optimization. Default: ``-1`` """ from thunder.backend_optimizer.optimizer import OptimizerType From a86630c1021532aa7e472ffbec80b394233ad5f4 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 15:22:04 +0300 Subject: [PATCH 142/171] Renamed class --- examples/dev/litGPT.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index b32c128bc0..51a370b662 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -18,7 +18,7 @@ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn -class Test: +class LitGPTModelThunderConfig: def __init__( self, layers: int, @@ -41,7 +41,7 @@ def __init__( to_run = [ - Test( + LitGPTModelThunderConfig( 1, "runtime", 1, @@ -96,7 +96,7 @@ def __init__( print(f"\nResults torch timer benchmark ({iters} iters):") torch_timer_total_benchmark(callables, labels, inputs, test.model_name, torch.nn.functional.cross_entropy) except Exception as e: - print(f"Test failed:\n{e}") + print(f"Benchmark failed:\n{e}") import traceback traceback.print_exc() From 9383c5e180047520ca632c0a2f536f25727a362c Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 15:27:40 +0300 Subject: [PATCH 143/171] Removed print --- thunder/benchmarks/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index c19f25cffb..ac4a9fe1bf 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -181,7 +181,6 @@ def torch_timer_total_benchmark( description=l, ) results.append(t.blocked_autorange(min_run_time=1)) - print(results[-1]) compare = Compare(results) compare.colorize(rowwise=True) compare.print() From e5f4e7dfb10af3c6330147ddf64c907e14bb8111 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 15:29:41 +0300 Subject: [PATCH 144/171] Updated doc --- thunder/benchmarks/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index ac4a9fe1bf..bc392f50de 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -83,6 +83,7 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int, labels: a list of labels (names) referring to the models. inputs: a list of inputs to give to models' forward pass. iters: benchmark iterations. + loss_fn: a Pytorch loss function. """ for m, input, label in zip(models, inputs, labels): # Warm up @@ -196,6 +197,7 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int, labels: a list of labels (names) referring to the models. inputs: a list of inputs to give to models' forward pass. iters: benchmark iterations. + loss_fn: a Pytorch loss function. """ for m, input, label in zip(models, inputs, labels): # Warm up From 0f98689acd7257fe56e5e20388d6d49a976c3306 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 15:42:15 +0300 Subject: [PATCH 145/171] Updated doc and removed unused imports --- examples/dev/litGPT.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/dev/litGPT.py b/examples/dev/litGPT.py index 51a370b662..6775c6dfd3 100644 --- a/examples/dev/litGPT.py +++ b/examples/dev/litGPT.py @@ -1,12 +1,9 @@ """ -This script benchmarks litGPT models in a easier way wrt to benchmark_litgpt.py with a fake training loop with no optimizers. +This script benchmarks litGPT models in a easier way wrt thunder.benchmarks.benchmark_litgpt.py with a fake training loop with no optimizers. """ from litgpt import GPT from thunder.benchmarks.utils import ( - thunder_fw_bw_benchmark, - torch_fw_bw_benchmark, - torch_fw_bw_benchmark_nvsight, torch_total_benchmark, torch_timer_total_benchmark ) From 38f176ad2f8f79952ee125fbadd15962856e8511 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 15:58:20 +0300 Subject: [PATCH 146/171] Restored partial trace benchmark options --- thunder/backend_optimizer/optimizer.py | 34 ++++++++++++++++++-------- thunder/backend_optimizer/utils.py | 4 +-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index cee8f5070b..563619f72d 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -21,6 +21,8 @@ logging.basicConfig(level=logging.INFO, format="[{name}]: {message}", style="{") logger = logging.getLogger("Thunder Autotuner") +# Control if single trace regions or partial traces are benchmarked during OperatorExecutor tuning +_benchmark_single_trace_region = False class OptimizationAlgorithm(Enum): """ @@ -316,7 +318,7 @@ def __init__( self.apply_bucketing_bw_trace: bool = apply_bucketing_bw_trace - self.benchmark_iters: int = 20 + self.benchmark_iters: int = 5 self.compile_data = compile_data @@ -819,9 +821,6 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor else: logger.debug(f"Available executors for single region:\n{candidate_executors}") - # Define the standalone trace in order to benchmark this symbol - subtrace = construct_trace()(current_bsym.sym, *current_bsym.args, **current_bsym.kwargs) - # Helpers candidate_best_time = BenchmarkResult() candidate_best_mem = BenchmarkResult() @@ -831,16 +830,31 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor candidate_best_time = BenchmarkResult(index=0) candidate_best_mem = BenchmarkResult(index=0) else: + if _benchmark_single_trace_region: + # Define the standalone trace in order to benchmark this symbol + subtrace = construct_trace()(current_bsym.sym, *current_bsym.args, **current_bsym.kwargs) + # Search for best candidate for i, candidate in enumerate(candidate_executors): - from thunder.common import transform_for_execution - subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] - logger.debug(f"Subtrace to benchmark single symbol:\n{subtrace_placed}") - t, m, _ = benchmark_trace( - subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + if _benchmark_single_trace_region: + from thunder.common import transform_for_execution + subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] + logger.debug(f"Subtrace to benchmark single symbol:\n{subtrace_placed}") + t, m, _ = benchmark_trace( + subtrace_placed, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) + else: + # Match the current candidate into helper dicts to benchmark partial trace + match_bsym_executor(current_bsym, [dict_time_strat, dict_mem_strat], candidate) + # Retrieve partial trace and benchmark, apply remat if possible + trc, _, _ = get_placed_trace(dict_time_strat, increasing_symbols) + t, m, _ = benchmark_trace( + trc, self.benchmark_iters, fw_trace=self.active_fw_trace_ctx[0] + ) + logger.info( + f"Operator excutor [{candidate.name}] candidate perf (is single trace region: {_benchmark_single_trace_region}): {t} ms {m/(2**30)} GB" ) - logger.debug(f"Operator excutor [{candidate.name}] candidate perf: {t} ms {m/(2**30)} GB") # Update results if t < candidate_best_time.runtime: candidate_best_time = BenchmarkResult(time=t, index=i) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 59a72d5327..924a8ee72c 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -473,7 +473,7 @@ def clone_args_if_needed(args): def compute_time_cost_nsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: - warm_up_iters = 50 + warm_up_iters = 10 torch.cuda.empty_cache() torch.cuda.synchronize() # Warm up cycles @@ -503,7 +503,7 @@ def compute_time_cost_nsight(fn: Callable, iters: int, *args) -> tuple[float, fl def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: try: current_iter = 0 - warm_up_iters = 50 + warm_up_iters = 10 out = None # Warm up cycles From a3c037c9244348c90cb5cc17f094c2c01dcce4f7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 16:28:01 +0300 Subject: [PATCH 147/171] Fixed function name typo and renamed dir --- examples/{dev => autotuner}/.gitignore | 0 examples/{dev => autotuner}/LLaMAMLP.py | 34 ++--- examples/{dev => autotuner}/litGPT.py | 12 +- examples/dev/nanogpt.py | 188 ------------------------ examples/dev/nvfuser_optimizations.py | 76 ---------- examples/dev/sdpa.py | 62 -------- examples/dev/te.py | 69 --------- thunder/benchmarks/utils.py | 2 +- 8 files changed, 18 insertions(+), 425 deletions(-) rename examples/{dev => autotuner}/.gitignore (100%) rename examples/{dev => autotuner}/LLaMAMLP.py (52%) rename examples/{dev => autotuner}/litGPT.py (91%) delete mode 100644 examples/dev/nanogpt.py delete mode 100644 examples/dev/nvfuser_optimizations.py delete mode 100644 examples/dev/sdpa.py delete mode 100644 examples/dev/te.py diff --git a/examples/dev/.gitignore b/examples/autotuner/.gitignore similarity index 100% rename from examples/dev/.gitignore rename to examples/autotuner/.gitignore diff --git a/examples/dev/LLaMAMLP.py b/examples/autotuner/LLaMAMLP.py similarity index 52% rename from examples/dev/LLaMAMLP.py rename to examples/autotuner/LLaMAMLP.py index 21ff039207..85df3a4200 100644 --- a/examples/dev/LLaMAMLP.py +++ b/examples/autotuner/LLaMAMLP.py @@ -1,12 +1,11 @@ """ -This benchmark script is intended to demonstrate the optimizer on a generic model. -No executor are given leaving full responsibility to the engine. +This benchmark script is intended to demonstrate the autotuner on a generic model. +No executor are given leaving full responsibility to Thunder. """ import torch import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_fw_bw_benchmark, torch_total_benchmark - +from thunder.benchmarks.utils import torch_timer_total_benchmark, torch_total_benchmark class LLaMAMLP(torch.nn.Module): def __init__(self, n_embd, intermediate_size) -> None: @@ -30,30 +29,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: model = LLaMAMLP(a, b) + eager = model + torchcompile = torch.compile(model) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type="runtime", autotune_enable_te=True) + jmodel_auto = thunder.jit(model, autotune_type="memory", autotune_enable_te=False, autotune_nv_enable_options=True) print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 - print("Results with thunder benchmark:") - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - - callables = [jmodel_def, jmodel_auto] - labels = ["def", "auto"] - inputs = [x, x] - print("\nResults with torch fw bw benchmark:") - torch_fw_bw_benchmark(callables, labels, inputs, iters) + callables = [eager, torchcompile, jmodel_def, jmodel_auto] + labels = ['eager', 'torchcompile', 'Thunder', 'Thunder Autotuned'] + inputs = [x, x, x, x] print("\nResults with torch total benchmark:") torch_total_benchmark(callables, labels, inputs, iters) + print("\nResults with torch timer benchmark:") + torch_timer_total_benchmark(callables, labels, inputs, "LlamaMLP") diff --git a/examples/dev/litGPT.py b/examples/autotuner/litGPT.py similarity index 91% rename from examples/dev/litGPT.py rename to examples/autotuner/litGPT.py index 6775c6dfd3..b3df4309c7 100644 --- a/examples/dev/litGPT.py +++ b/examples/autotuner/litGPT.py @@ -10,7 +10,7 @@ from thunder.tests.litgpt_model import Config import thunder import torch -# import time +import time torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn @@ -78,11 +78,11 @@ def __init__( autotune_optimize_common_blocks=test.optimize_transformer_blocks, autotune_optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, ) - # print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) - # s = time.time_ns() - # print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) - # e = time.time_ns() - # print("Compilation time:", {(e - s) / 1000000000}, "s") + print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) + s = time.time_ns() + print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) + e = time.time_ns() + print("Compilation time:", {(e - s) / 1000000000}, "s") iters = 100 callables = [eager, torch_compile, jmodel_def, jmodel_auto] diff --git a/examples/dev/nanogpt.py b/examples/dev/nanogpt.py deleted file mode 100644 index bd6e7323a7..0000000000 --- a/examples/dev/nanogpt.py +++ /dev/null @@ -1,188 +0,0 @@ -""" -This benchmark script is intended to demonstrate the optimizer on nanoGPT model. -The script runner is taken from: https://github.com/karpathy/nanoGPT/blob/master/bench.py -""" - -import torch -import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark -from thunder.tests.nanogpt_model import GPTConfig, GPT -from contextlib import nullcontext - -warm_up_iters = 50 - - -def run(target: str = "runtime"): - if target != "runtime" and target != "memory": - raise AssertionError(f"Target {target} not supported. Only runtime and memory available") - # ----------------------------------------------------------------------------- - batch_size = 12 - block_size = 1024 - bias = False - real_data = False - seed = 1337 - device = "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. - dtype = ( - "bfloat16" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "float16" - ) # 'float32' or 'bfloat16' or 'float16' - compile = False # use PyTorch 2.0 to compile the model to be faster - profile = False # use pytorch profiler, or just simple benchmarking? - # exec(open('configurator.py').read()) # overrides from command line or config file - # ----------------------------------------------------------------------------- - - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul - torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn - device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast - ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype] - ctx = nullcontext() if device_type == "cpu" else torch.amp.autocast(device_type=device_type, dtype=ptdtype) - - # data loading init - if real_data: - raise RuntimeError("Not supported") - else: - # alternatively, if fixed data is desired to not care about data loading - x = torch.randint(50304, (batch_size, block_size), device=device) - y = torch.randint(50304, (batch_size, block_size), device=device) - get_batch = lambda split: (x, y) - - # model init - gptconf = GPTConfig( - block_size=block_size, # how far back does the model look? i.e. context size - n_layer=4, - n_head=12, - n_embd=768, # size of the model - dropout=0, # for determinism - bias=bias, - ) - model = GPT(gptconf) - model.to(device) - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit( - model, - autotune_type=target, - executors=["torchcompile", "nvfuser", "cudnn", "sdpa", "transformer_engine"], - use_cudagraphs=False, - ) - - if compile: - print("Compiling model...") - model = torch.compile(model) # pytorch 2.0 - - if profile: - # useful docs on pytorch profiler: - # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html - # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile - wait, warmup, active = 5, 5, 5 - num_steps = wait + warmup + active - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], - schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), - on_trace_ready=torch.profiler.tensorboard_trace_handler("./bench_log"), - record_shapes=False, - profile_memory=False, - with_stack=False, # incurs an additional overhead, disable if not needed - with_flops=True, - with_modules=False, # only for torchscript models atm - ) as prof: - models = [jmodel_def, jmodel_auto] - - for mod in models: - print("Profiling new model") - X, Y = get_batch("train") - for k in range(num_steps): - with ctx: - _, loss = mod(X, Y) - X, Y = get_batch("train") - loss.backward() - lossf = loss.item() - print(f"{k}/{num_steps} loss: {lossf:.4f}") - - prof.step() # notify the profiler at end of each step - - else: - # simple benchmarking - def measure(m, label): - iters = 100 - torch.cuda.synchronize() - - for i in range(warm_up_iters): - X, Y = get_batch("train") - with ctx: - _, loss = m(X, Y) - loss.backward() - - start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] - stream = torch.cuda.current_stream() - torch.cuda.synchronize() - for i in range(iters): - torch.cuda.empty_cache() - torch.cuda._sleep(1_000_000) - X, Y = get_batch("train") - start_events[i].record(stream) - with ctx: - _, loss = m(X, Y) - loss.backward() - end_events[i].record(stream) - - torch.cuda.synchronize() - tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] - tot_time = sum(tot) / iters - print("\n\nResults torch benchmark:") - print(f"{label} tot time: {tot_time} ms") - - def measure_nvsight(m, label): - # Warm up - for _ in range(warm_up_iters): - X, Y = get_batch("train") - with ctx: - _, loss = m(X, Y) - loss.backward() - - torch.cuda.empty_cache() - torch.cuda.synchronize() - torch.cuda.cudart().cudaProfilerStart() - # Perform less iterations - for _ in range(20): - torch.cuda.empty_cache() - X, Y = get_batch("train") - torch.cuda.nvtx.range_push(f"{label}: fw-bw") - with ctx: - _, loss = m(X, Y) - loss.backward() - torch.cuda.nvtx.range_pop() - torch.cuda.cudart().cudaProfilerStop() - - measure(jmodel_auto, "auto") - measure(jmodel_def, "def") - - print("\n\nResults thunder benchmark:") - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, 100) - - measure_nvsight(jmodel_def, "def") - measure_nvsight(jmodel_auto, "auto") - - traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - for t in traces: - print(f"{t}\n############################################") - - -run() diff --git a/examples/dev/nvfuser_optimizations.py b/examples/dev/nvfuser_optimizations.py deleted file mode 100644 index 6a97191600..0000000000 --- a/examples/dev/nvfuser_optimizations.py +++ /dev/null @@ -1,76 +0,0 @@ -""" -This benchmark script is intended to demonstrate the optimizer optimizing the nvFuser executor with its compile options. - -nvFuser compile options can be autotune with the argument `autotune_enable_nvfuser_all=True`. -""" - -import torch -import thunder -from thunder.benchmarks.utils import ( - thunder_fw_bw_benchmark, - torch_fw_bw_benchmark, - torch_fw_bw_benchmark_nvsight, - torch_total_benchmark, -) - - -class Module(torch.nn.Module): - def __init__(self, in_features, out_features) -> None: - super().__init__() - self.linear = torch.nn.Sequential( - torch.nn.Linear(in_features, out_features), - torch.nn.Linear(out_features, in_features), - torch.nn.Linear(in_features, out_features), - torch.nn.Linear(out_features, in_features), - ) - self.silu = torch.nn.SiLU() - - def forward(self, x: torch.Tensor): - b = self.linear(x) - c = b @ torch.transpose(b, 0, 1) - for _ in range(4): - c = c @ torch.transpose(c, 0, 1) - return self.silu(c) - - -with torch.device("cuda"): - in_features = 1 << 8 - out_features = 1 << 10 - model = Module(in_features, out_features) - x = torch.randn(1 << 9, in_features, requires_grad=True) - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit( - model, autotune_type="runtime", executors=["nvfuser", "cudnn"], autotune_enable_nvfuser_all=True - ) - - y = jmodel_def(x) - y = jmodel_auto(x) - - iters = 100 - print("Results thunder benchmark:") - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - # thunder_fw_bw_benchmark(traces, labels, iters, nvsight=True) - - callables = [jmodel_def, jmodel_auto] - labels = ["def", "auto"] - inputs = [x, x] - print("Results torch benchmark:") - torch_fw_bw_benchmark(callables, labels, inputs, iters) - torch_total_benchmark(callables, labels, inputs, iters) - torch_fw_bw_benchmark_nvsight(callables, labels, inputs, iters) - - for t in fw_traces: - print(f"{t}\n#########################################") - for t in bw_traces: - print(f"{t}\n#########################################") diff --git a/examples/dev/sdpa.py b/examples/dev/sdpa.py deleted file mode 100644 index f50ff2bf84..0000000000 --- a/examples/dev/sdpa.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -This benchmark script is intended to demonstrate the optimizer working on -the single trace region bext executor (when the forward trace symbol will influence the backward trace). - -Set the log level at least to INF0 in `thunder/backend_optimizer/optimizer.py`. -""" - -import torch -import thunder -from thunder.benchmarks.utils import thunder_fw_bw_benchmark, torch_total_benchmark - -dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 -torch.set_default_dtype(dtype) - - -class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, query, key, value): - a = torch.nn.functional.scaled_dot_product_attention(query, key, value) - # Make different inputs as happens in a real model - b = torch.nn.functional.scaled_dot_product_attention(query + query, key + key, value + value) - return a + b - - -with torch.device("cuda"): - model = Model() - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type="runtime", executors=["nvfuser", "cudnn", "sdpa"]) - - q = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) - k = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) - v = torch.rand(32, 8, 128, 64 * 1, requires_grad=True) - - jmodel_def(q, k, v) - jmodel_auto(q, k, v) - - iters = 100 - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - print("Thunder benchmark:") - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - - print("\n\n\n\n\n\n") - print(f"{thunder.last_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_traces(jmodel_auto)[-1]}") - - print("\n\n") - print(f"{thunder.last_backward_traces(jmodel_def)[-1]}") - print("###############################################################################") - print(f"{thunder.last_backward_traces(jmodel_auto)[-1]}") diff --git a/examples/dev/te.py b/examples/dev/te.py deleted file mode 100644 index f4ec28a606..0000000000 --- a/examples/dev/te.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -This benchmark script is intended to demonstrate the optimizer supporting the transformer engine executor. - -This option can be enabled inside the autotuner by using the flag `autotune_enable_te=True`. -""" - -import torch -import thunder -from thunder.benchmarks.utils import ( - thunder_fw_bw_benchmark, - torch_fw_bw_benchmark, - torch_fw_bw_benchmark_nvsight, - torch_total_benchmark, -) - - -class Module(torch.nn.Module): - def __init__(self, in_features, out_features) -> None: - super().__init__() - self.linear = torch.nn.Sequential( - torch.nn.Linear(in_features, out_features), - torch.nn.Linear(out_features, in_features), - torch.nn.Linear(in_features, out_features), - ) - - def forward(self, x: torch.Tensor): - return self.linear(x) - - -with torch.device("cuda"): - m = 1 - in_features = 4096 * m - out_features = 4096 * m - model = Module(in_features, out_features) - x = torch.randn(768, in_features, requires_grad=True) - - jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit( - model, - autotune_type="runtime", - executors=[ - "nvfuser", - "transformer_engine", - ], - autotune_enable_te=True, - ) - - y = jmodel_def(x) - y = jmodel_auto(x) - - iters = 100 - fw_traces = [ - thunder.last_traces(jmodel_def)[-1], - thunder.last_traces(jmodel_auto)[-1], - ] - bw_traces = [ - thunder.last_backward_traces(jmodel_def)[-1], - thunder.last_backward_traces(jmodel_auto)[-1], - ] - fw_labels = ["fw_def", "fw_auto"] - bw_labels = ["bw_def", "bw_auto"] - print("Results thunder benchmark:") - thunder_fw_bw_benchmark(fw_traces, bw_traces, fw_labels, bw_labels, iters) - - callables = [jmodel_def, jmodel_auto] - labels = ["def", "auto"] - inputs = [x, x] - print("\n\nResults torch benchmark:") - torch_total_benchmark(callables, labels, inputs, iters) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index bc392f50de..4d48c66c54 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -168,7 +168,7 @@ def torch_timer_total_benchmark( """ if loss_fn is not None else """ - _run_atograd(m, i) + _run_autograd(m, i) """, globals={ "i": i, From 46e3a715b1439415cb885ba2b7b8b379d954fb5d Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 16:33:15 +0300 Subject: [PATCH 148/171] Changed optimizaton type --- examples/autotuner/LLaMAMLP.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/autotuner/LLaMAMLP.py b/examples/autotuner/LLaMAMLP.py index 85df3a4200..d56f413542 100644 --- a/examples/autotuner/LLaMAMLP.py +++ b/examples/autotuner/LLaMAMLP.py @@ -7,6 +7,7 @@ import thunder from thunder.benchmarks.utils import torch_timer_total_benchmark, torch_total_benchmark + class LLaMAMLP(torch.nn.Module): def __init__(self, n_embd, intermediate_size) -> None: super().__init__() @@ -22,24 +23,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.device("cuda"): - mult = 1 + mult = 2 a = 4096 * mult b = 11008 * mult - x = torch.randn(2, 2048, a, requires_grad=True) + x = torch.randn(4, 2048, a, requires_grad=True) model = LLaMAMLP(a, b) eager = model torchcompile = torch.compile(model) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type="memory", autotune_enable_te=False, autotune_nv_enable_options=True) + jmodel_auto = thunder.jit(model, autotune_type="runtime", autotune_enable_te=False, autotune_nv_enable_options=True) print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) iters = 100 callables = [eager, torchcompile, jmodel_def, jmodel_auto] - labels = ['eager', 'torchcompile', 'Thunder', 'Thunder Autotuned'] + labels = ["eager", "torchcompile", "Thunder", "Thunder Autotuned"] inputs = [x, x, x, x] print("\nResults with torch total benchmark:") torch_total_benchmark(callables, labels, inputs, iters) From ff9f661ec2b75ebb5f2c60e6bccb30d866760431 Mon Sep 17 00:00:00 2001 From: matteochen Date: Tue, 17 Sep 2024 16:33:49 +0300 Subject: [PATCH 149/171] Formatter --- examples/autotuner/litGPT.py | 6 ++---- thunder/backend_optimizer/optimizer.py | 3 ++- thunder/benchmarks/utils.py | 14 ++++++++++++-- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/autotuner/litGPT.py b/examples/autotuner/litGPT.py index b3df4309c7..d2911d309c 100644 --- a/examples/autotuner/litGPT.py +++ b/examples/autotuner/litGPT.py @@ -3,10 +3,7 @@ """ from litgpt import GPT -from thunder.benchmarks.utils import ( - torch_total_benchmark, - torch_timer_total_benchmark -) +from thunder.benchmarks.utils import torch_total_benchmark, torch_timer_total_benchmark from thunder.tests.litgpt_model import Config import thunder import torch @@ -15,6 +12,7 @@ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn + class LitGPTModelThunderConfig: def __init__( self, diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 563619f72d..8038b295b7 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -24,6 +24,7 @@ # Control if single trace regions or partial traces are benchmarked during OperatorExecutor tuning _benchmark_single_trace_region = False + class OptimizationAlgorithm(Enum): """ Represents the optimization technique used by the autotuner. @@ -836,9 +837,9 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor # Search for best candidate for i, candidate in enumerate(candidate_executors): - if _benchmark_single_trace_region: from thunder.common import transform_for_execution + subtrace_placed = transform_for_execution(subtrace, executors_list=[candidate])[-1] logger.debug(f"Subtrace to benchmark single symbol:\n{subtrace_placed}") t, m, _ = benchmark_trace( diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index 4d48c66c54..f40c8eff36 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -26,6 +26,7 @@ def __init__( self.bw_fn: Callable | None = bw_fn self.executor = executor + def _run_loss(model, input, target, loss_fn): logits = model(input) logits = logits.reshape(-1, logits.size(-1)) @@ -33,10 +34,12 @@ def _run_loss(model, input, target, loss_fn): loss = loss_fn(logits, target) loss.backward() + def _run_autograd(model, input): y = model(input) torch.autograd.grad(y, input, grad_outputs=torch.ones_like(y)) + def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iters: int, loss) -> None: """ Benchmark a mock trainig loop of the given models. The loss function is defined as a naive torch.sum(). @@ -72,7 +75,9 @@ def torch_fw_bw_benchmark_nvsight(models: list, labels: list, inputs: list, iter torch.cuda.cudart().cudaProfilerStop() -def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None) -> None: +def torch_fw_bw_benchmark( + models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None +) -> None: """ Benchmark a mock trainig loop of the given models. Time measurements will be performed by using cuda events. A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. @@ -146,6 +151,7 @@ def torch_fw_bw_benchmark(models: list, labels: list, inputs: list, iters: int, print(f"{label} tot bw time: {tot_time} ms") print(f"{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB") + def torch_timer_total_benchmark( models: list, labels: list, inputs: list, name: str = "Model", loss_fn: Callable | None = None ) -> None: @@ -186,7 +192,10 @@ def torch_timer_total_benchmark( compare.colorize(rowwise=True) compare.print() -def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None) -> None: + +def torch_total_benchmark( + models: list, labels: list, inputs: list, iters: int, loss_fn: Callable | None = None +) -> None: """ Benchmark a mock trainig loop of the given models. Time measurements will be performed by using cuda events. A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. @@ -238,6 +247,7 @@ def torch_total_benchmark(models: list, labels: list, inputs: list, iters: int, print(f"{label} tot time: {tot_time} ms") print(f"{label} max allocated memory: {max_allocated_bytes / (2**30)} GB") + def thunder_fw_bw_benchmark( fw_traces: list, bw_traces: list, fw_labels: list, bw_labels: list, iters: int, nvsight: bool = False ) -> None: From cdc9157aed5b03f76dff22193e5eeb98af4b68e4 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 18 Sep 2024 11:07:20 +0300 Subject: [PATCH 150/171] Enhanced logs description --- thunder/benchmarks/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/thunder/benchmarks/utils.py b/thunder/benchmarks/utils.py index f40c8eff36..b6e165cd66 100644 --- a/thunder/benchmarks/utils.py +++ b/thunder/benchmarks/utils.py @@ -118,8 +118,8 @@ def torch_fw_bw_benchmark( torch.cuda.synchronize() tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] tot_time = sum(tot) / iters - print(f"{label} tot fw time: {tot_time} ms") - print(f"{label} max fw allocated memory: {max_allocated_bytes / (2**30)} GB") + print(f"{label} forward mean time: {tot_time} ms") + print(f"{label} peak forward allocated memory: {max_allocated_bytes / (2**30)} GB") start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] @@ -148,15 +148,15 @@ def torch_fw_bw_benchmark( torch.cuda.synchronize() tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] tot_time = sum(tot) / iters - print(f"{label} tot bw time: {tot_time} ms") - print(f"{label} max bw allocated memory: {max_allocated_bytes / (2**30)} GB") + print(f"{label} backward mean time: {tot_time} ms") + print(f"{label} peak backward allocated memory: {max_allocated_bytes / (2**30)} GB") def torch_timer_total_benchmark( models: list, labels: list, inputs: list, name: str = "Model", loss_fn: Callable | None = None ) -> None: """ - Benchmark a mock trainig loop time of the given models. Measurements will be computed by using torch.utils.benchmark.Timer. + Benchmark a mock trainig loop time of the given models. Measurements will be computed by using torch.utils.benchmark.Timer, median times will be provided. A loss function is applied to trigger backward if provided. Otherwise torch.autograd will be used. Args: @@ -244,8 +244,8 @@ def torch_total_benchmark( torch.cuda.synchronize() tot = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] tot_time = sum(tot) / iters - print(f"{label} tot time: {tot_time} ms") - print(f"{label} max allocated memory: {max_allocated_bytes / (2**30)} GB") + print(f"{label} forward+backward mean time: {tot_time} ms") + print(f"{label} peak forward+backward allocated memory: {max_allocated_bytes / (2**30)} GB") def thunder_fw_bw_benchmark( From 6954062fc25653d3491dea75855554ba08ab6f02 Mon Sep 17 00:00:00 2001 From: matteochen Date: Wed, 18 Sep 2024 14:32:49 +0300 Subject: [PATCH 151/171] Small fixes to align main --- examples/autotuner/litGPT.py | 1 - thunder/__init__.py | 5 ++--- thunder/backend_optimizer/optimizer.py | 9 --------- thunder/backend_optimizer/utils.py | 2 +- thunder/tests/test_autotuner.py | 9 +-------- 5 files changed, 4 insertions(+), 22 deletions(-) diff --git a/examples/autotuner/litGPT.py b/examples/autotuner/litGPT.py index d2911d309c..f6135a46eb 100644 --- a/examples/autotuner/litGPT.py +++ b/examples/autotuner/litGPT.py @@ -72,7 +72,6 @@ def __init__( model, autotune_type=test.autotune_type, executors=test.executors, - use_cudagraphs=False, autotune_optimize_common_blocks=test.optimize_transformer_blocks, autotune_optimize_common_blocks_min_size=test.optimize_transformer_min_block_size, ) diff --git a/thunder/__init__.py b/thunder/__init__.py index d2c98ba94d..a2b85a307b 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -330,12 +330,11 @@ def jit( # Default the executors list to all_executors if no options are given # Otherwise the user restricted choice will be used from thunder.executors.transformer_engineex import transformer_engine_ex - from thunder.executors.cudagraphex import cudagraphex from thunder.executors.pythonex import ex as python_ex if not executors: executors = get_all_executors() - # Remove python and cudagraph - executors = [ex for ex in executors if ex != python_ex and ex != cudagraphex] + # Remove pythonex + executors = [ex for ex in executors if ex != python_ex] # Remove transformer_engine if not requested executors = [ ex diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 8038b295b7..5162b3200d 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -441,15 +441,6 @@ def _best_runtime_and_memory_candidates(self, candidates: Sequence[OutputCandida (pair.fw, pair.bw, pair.compile_opt, pair.executors_fw, pair.executors_bw, False), (remat_fw, remat_bw, pair.compile_opt, pair.executors_fw, pair.executors_bw, True), ] - # if self.compile_data.use_cudagraphs is not None and self.compile_data.use_cudagraphs: - # from thunder.executors.cudagraphex import cudagraphex - - # pair_options.extend( - # [ - # (cudagraphex.fusion_pass(pair.fw), cudagraphex.fusion_pass(pair.bw)), - # (cudagraphex.fusion_pass(remat_fw), cudagraphex.fusion_pass(remat_bw)), - # ] - # ) # Select the best options for pair_option in pair_options: fw, bw, compile_opt, executors_fw, executors_bw, remat_applied = pair_option diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 924a8ee72c..fc90251931 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -869,7 +869,7 @@ def transform_proxies_to_real(sequence: Sequence, level=0, **kwargs) -> tuple | res = [] for e in sequence: - if type(e) is tuple: + if isinstance(e, Sequence): res.append(transform_proxies_to_real(e, level + 1, **kwargs)) else: if isinstance(e, TensorProxy): diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 336db20986..83378672b7 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -497,7 +497,7 @@ def forward(self, x): @pytest.mark.parametrize( - "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_cudagraphs, use_te", + "model, tensor_shape, dtype, autotune_type, executors, expected_executors, use_te", [ ( Model_1(32, 32), @@ -506,7 +506,6 @@ def forward(self, x): "runtime", [nvfuserex], [[nvfuserex, torchex, pythonex]], - True, False, ), ( @@ -516,7 +515,6 @@ def forward(self, x): "memory", [torch_compile_ex], [[torch_compile_ex, torchex, pythonex]], - True, False, ), ( @@ -526,7 +524,6 @@ def forward(self, x): "runtime", [transformer_engine_ex], [[transformer_engine_ex, nvfuserex, torchex, pythonex]], - False, True, ), ( @@ -537,7 +534,6 @@ def forward(self, x): [sdpa_ex, cudnn_ex], [[sdpa_ex, nvfuserex, torchex, pythonex], [cudnn_ex, nvfuserex, torchex, pythonex]], False, - False, ), ( Model_2(), @@ -549,7 +545,6 @@ def forward(self, x): [sdpa_ex, transformer_engine_ex, nvfuserex, torchex, pythonex], [transformer_engine_ex, sdpa_ex, nvfuserex, torchex, pythonex], ], - False, True, ), ], @@ -562,7 +557,6 @@ def test_autotuner( autotune_type: str, executors: list, expected_executors: list[list], - use_cudagraphs: bool, use_te: bool, ): def _run(): @@ -573,7 +567,6 @@ def _run(): model, autotune_type=autotune_type, executors=executors, - use_cudagraphs=use_cudagraphs, autotune_enable_te=use_te, ) y_def = jitted_def(x) From 87a79fdf27f186cb6497683d2f56554f3d18d06c Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 10:31:17 +0300 Subject: [PATCH 152/171] Fixed OOM errors during trace benchmarks leading to a premature end of the tuning process --- thunder/backend_optimizer/utils.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index fc90251931..ea49a9854d 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -549,7 +549,7 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl def build_static_args(sequence: Sequence, **kwargs) -> list: return transform_proxies_to_real(sequence, level=0, **kwargs) - def backward_trace_args_preprocess() -> list: + def backward_trace_args_preprocess() -> list | None: if "fw_trace" not in kwargs: raise RuntimeError( "Set the associated forward trace in order to benchmark backward pass with sdpa executor" @@ -559,6 +559,9 @@ def backward_trace_args_preprocess() -> list: raise AssertionError(f"forward trace is not a TraceCtx. Received: {type(fw_trace)}") # Run the fw trace and get the outputs fw_output = benchmark_trace(fw_trace, apply_del_last_used=False)[2] + # If any issue with the forward trace benchmark we have to stop this backward benchmark too (usually OOM errors) + if fw_output is None: + return None # Check if the fw trace is a final trace or an intermediate one (used for single trace region benchmarks) sig = fw_trace.signature_with_no_ctx() @@ -639,7 +642,11 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa input_args = backward_trace_args_preprocess() # Forward or computational trace, parse the compile time input args... else: - input_args: list = build_static_args(trace.args, te_used=te_used) + input_args = build_static_args(trace.args, te_used=te_used) + + # Can not parse input args (usually due to OOM errors in upstream calls) + if input_args is None: + return float('inf'), float('inf'), None # Obtain the python executable string executable_str = trace.python() From 75b856dba0296271ecc12762cb66fb58d2699dec Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 11:49:05 +0300 Subject: [PATCH 153/171] Handled nvmath missing installation --- thunder/executors/nvmathex.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/thunder/executors/nvmathex.py b/thunder/executors/nvmathex.py index b0ae70dbb7..6d4e8a5eb8 100644 --- a/thunder/executors/nvmathex.py +++ b/thunder/executors/nvmathex.py @@ -1,11 +1,17 @@ from importlib.metadata import version from thunder.core.prims import PrimIDs import logging -import nvmath import thunder import thunder.torch as ltorch import torch +try: + import nvmath + HAS_NVMATH = True +except: + pass + HAS_NVMATH = False + logger = logging.getLogger("Thunder nvmath_ex") logger.disabled = True @@ -43,7 +49,7 @@ def _nvmath_linalg_advanced_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> tor return result def _nvmath_linalg_advanced_matmul_checker(*args, **kwargs) -> bool: - return True + return HAS_NVMATH nvmath_linalg_advanced_matmul = nvmath_ex.register_operator( "nvmath_linalg_advanced_matmul", From f2b934b2d4e565b5854e7da48e6299e863fbf611 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 11:56:46 +0300 Subject: [PATCH 154/171] Prev commit --- thunder/executors/nvmathex.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/thunder/executors/nvmathex.py b/thunder/executors/nvmathex.py index 6d4e8a5eb8..d40ff10e57 100644 --- a/thunder/executors/nvmathex.py +++ b/thunder/executors/nvmathex.py @@ -8,18 +8,19 @@ try: import nvmath HAS_NVMATH = True + version = version('nvmath-python') except: pass HAS_NVMATH = False + version = None logger = logging.getLogger("Thunder nvmath_ex") logger.disabled = True -nvmath_ex = thunder.extend.OperatorExecutor("nvmath", version=version('nvmath-python')) +nvmath_ex = thunder.extend.OperatorExecutor("nvmath", version=version) thunder.extend.register_executor(nvmath_ex) _cache = {} -options = nvmath.linalg.advanced.MatmulOptions(logger=logger) def _cache_key(a: torch.Tensor, b: torch.Tensor) -> str: def _get_shape_str(t: tuple): @@ -28,6 +29,7 @@ def _get_shape_str(t: tuple): return f'{_get_shape_str(a.size())}-{_get_shape_str(b.size())}' def _nvmath_linalg_advanced_matmul_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + options = nvmath.linalg.advanced.MatmulOptions(logger=logger) # Check if these shapes have been cached k = _cache_key(a, b) if k in _cache: From 42401b133dccbc03171707693f7fb0b6703537e3 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 13:07:12 +0300 Subject: [PATCH 155/171] Prev commit --- thunder/tests/test_extend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/tests/test_extend.py b/thunder/tests/test_extend.py index 30015e76f4..58250a8a53 100644 --- a/thunder/tests/test_extend.py +++ b/thunder/tests/test_extend.py @@ -127,6 +127,7 @@ def test_get_all_executors_includes_all_native_executors(): "apex", "cudnn", "fa3", + "nvmath", "torch", "cudnn_layernorm", "sdpa", From 02f50380f7c16d4a3a80f74a3124df4735326ee0 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 13:08:28 +0300 Subject: [PATCH 156/171] Added comment --- thunder/executors/nvmathex.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/executors/nvmathex.py b/thunder/executors/nvmathex.py index d40ff10e57..388d716736 100644 --- a/thunder/executors/nvmathex.py +++ b/thunder/executors/nvmathex.py @@ -15,6 +15,7 @@ version = None logger = logging.getLogger("Thunder nvmath_ex") +# Disable nvmath logs logger.disabled = True nvmath_ex = thunder.extend.OperatorExecutor("nvmath", version=version) From fadafe476df79cfe0e05daa17c0de9b28cff7621 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 14:16:36 +0300 Subject: [PATCH 157/171] Updated test runner --- examples/autotuner/litGPT.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/autotuner/litGPT.py b/examples/autotuner/litGPT.py index f6135a46eb..8b0bf41058 100644 --- a/examples/autotuner/litGPT.py +++ b/examples/autotuner/litGPT.py @@ -8,6 +8,7 @@ import thunder import torch import time +from pprint import pprint torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn @@ -58,12 +59,12 @@ def __init__( if test.seq_len != -1: cfg.block_size = test.seq_len torch.set_default_dtype(torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16) - print(cfg) + pprint(cfg) + print("Batch size:", test.batch_size) with torch.device("cuda"): model = GPT(cfg) x = torch.randint(1, model.config.vocab_size, (test.batch_size, cfg.block_size)) target = torch.ones_like(x) - print(f"Input size: {x.size()}") eager = model torch_compile = torch.compile(model) From e4707fa37a0294d54ac40ebea57525fe7728fe48 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 14:21:30 +0300 Subject: [PATCH 158/171] Log applied executors --- examples/autotuner/litGPT.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/autotuner/litGPT.py b/examples/autotuner/litGPT.py index 8b0bf41058..0414c37877 100644 --- a/examples/autotuner/litGPT.py +++ b/examples/autotuner/litGPT.py @@ -90,6 +90,8 @@ def __init__( torch_total_benchmark(callables, labels, inputs, iters, torch.nn.functional.cross_entropy) print(f"\nResults torch timer benchmark ({iters} iters):") torch_timer_total_benchmark(callables, labels, inputs, test.model_name, torch.nn.functional.cross_entropy) + + print(f'Executors employed: {thunder.executors_applied(jmodel_auto)}') except Exception as e: print(f"Benchmark failed:\n{e}") import traceback From b728ab6416aca9a6fd621101a4fc68842b3ed60e Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 14:21:52 +0300 Subject: [PATCH 159/171] Torch timer for benchmarks --- thunder/backend_optimizer/utils.py | 82 +++++++++++++++++++++++------- 1 file changed, 64 insertions(+), 18 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index ea49a9854d..74024f44e7 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -22,6 +22,7 @@ import thunder.core.transforms as transforms from itertools import chain import torch +from torch.utils.benchmark import Timer, Compare from thunder.core.dtypes import dtype from enum import Enum @@ -451,6 +452,8 @@ def benchmark_trace( from thunder.executors.passes import del_last_used import inspect + warm_up_iters = 10 + torch.compiler.reset() # TODO: If TE is used inside the trace we have to clone the input arguments as @@ -471,15 +474,28 @@ def clone_args_if_needed(args): res.append(arg) return tuple(res) + def warm_up(fn: Callable, args: Sequence): + for _ in range(warm_up_iters): + new_args = clone_args_if_needed(args) + fn(*new_args) + + def memory_snapshot(fn: Callable, args: Sequence, file_name:str): + new_args = clone_args_if_needed(args) + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.cuda.memory._record_memory_history() + fn(*new_args) + torch.cuda.memory._dump_snapshot(file_name + "_benchmark.pickle") + torch.cuda.memory._record_memory_history(enabled=None) + def compute_time_cost_nsight(fn: Callable, iters: int, *args) -> tuple[float, float, Any]: try: - warm_up_iters = 10 torch.cuda.empty_cache() torch.cuda.synchronize() + # Warm up cycles - for _ in range(warm_up_iters): - new_args = clone_args_if_needed(args) - fn(*new_args) + warm_up(fn, args) + # Benchmark torch.cuda.empty_cache() torch.cuda.synchronize() @@ -497,28 +513,21 @@ def compute_time_cost_nsight(fn: Callable, iters: int, *args) -> tuple[float, fl import inspect trc = inspect.getsource(fn) - print(f"#Trace execution failed for nsight (error: {e}):\n{trc}") + print(f"Trace execution failed for nsight (error: {e}):\n\nTrace executed:\n{trc}") raise e def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[float, float, Any]: try: current_iter = 0 - warm_up_iters = 10 out = None # Warm up cycles - for _ in range(warm_up_iters): - new_args = clone_args_if_needed(args) - out = fn(*new_args) + warm_up(fn, args) + # Snapshot request if snapshot: - new_args = clone_args_if_needed(args) - torch.cuda.empty_cache() - torch.cuda.synchronize() - torch.cuda.memory._record_memory_history() - fn(*new_args) - torch.cuda.memory._dump_snapshot(snapshot_name + "_benchmark.pickle") - torch.cuda.memory._record_memory_history(enabled=None) + memory_snapshot(fn, args, snapshot_name) + # Benchmark stream = torch.cuda.current_stream() max_allocated_bytes = 0 @@ -543,7 +552,43 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl tot_time = sum(times) / iters return tot_time, max_allocated_bytes, out except Exception as e: - print(f"#Trace execution failed at iter {current_iter} (error: {e})\n{repr}") + print(f"Trace execution failed at iter {current_iter} (error: {e})\n\nTrace executed:\n{repr}") + raise e + + def compute_time_cost_ms_torchtimer(fn: Callable, repr: str, *args) -> tuple[float, float, Any]: + try: + out = None + + # Warm up cycles + warm_up(fn, args) + + # Snapshot request + if snapshot: + memory_snapshot(fn, args, snapshot_name) + + # Measure memory consumption + torch.cuda.reset_peak_memory_stats(torch.cuda.current_device()) + new_args = clone_args_if_needed(args) + # Cache the output + out = fn(*new_args) + max_allocated_bytes = torch.cuda.max_memory_allocated(torch.cuda.current_device()) + + # Benchmark + new_args = clone_args_if_needed(args) + # Omit any labels as we are not going to print the Timer result + t = Timer( + stmt=""" + fn(*new_args) + """, + globals={ + "fn": fn, + "new_args": new_args + }, + ) + t = t.blocked_autorange(min_run_time=1) + return t.median, max_allocated_bytes, out + except Exception as e: + print(f"Trace execution failed (error: {e})\n\nTrace executed:\n{repr}") raise e def build_static_args(sequence: Sequence, **kwargs) -> list: @@ -663,7 +708,8 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa if nsight: t, m, answer = compute_time_cost_nsight(executable, iters, *input_args) else: - t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + t, m, answer = compute_time_cost_ms_torchtimer(executable, executable_str, *input_args) + # t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) except Exception: import traceback From bb308050aeb471e78540a884bbc569eb04010bd7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 14:30:13 +0300 Subject: [PATCH 160/171] Doc --- thunder/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/thunder/__init__.py b/thunder/__init__.py index a2b85a307b..0505555ca8 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -300,6 +300,9 @@ def jit( autotune_optimize_common_blocks: boolean to enable trace's common block optimization during the compilation (for example transformer layers). This optimization can be used if you are working with a model with repeated block structures as transformer based models. You don't need to know where a block starts or ends as it's handled automatically. Default: ``"False"`` autotune_optimize_common_blocks_min_size: integer to control the minimum block length to trigger the common block optimization. Default: ``-1`` + autotune_save_configuration: boolean to produce a configuration file for the current model. This configuration can be loaded afterwards with ``"autotune_restore_configuration"``. Default ``"False"`` + autotune_restore_configuration: string containing the cached configuration file name with the relative path to the script invocation. + model_name: string containing the current model name used during the configuration file creation in ``"autotune_save_configuration"``. A default one is used if this is not provided. """ from thunder.backend_optimizer.optimizer import OptimizerType From 191f403b1791aad2291740931fc32ad98fb85840 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 15:25:28 +0300 Subject: [PATCH 161/171] Updated Anyproxy hash / updated test runner file --- examples/autotuner/LLaMAMLP.py | 9 ++++++++- thunder/backend_optimizer/utils.py | 17 +++++++++++++---- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/examples/autotuner/LLaMAMLP.py b/examples/autotuner/LLaMAMLP.py index d56f413542..9ec217347c 100644 --- a/examples/autotuner/LLaMAMLP.py +++ b/examples/autotuner/LLaMAMLP.py @@ -33,7 +33,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: eager = model torchcompile = torch.compile(model) jmodel_def = thunder.jit(model) - jmodel_auto = thunder.jit(model, autotune_type="runtime", autotune_enable_te=False, autotune_nv_enable_options=True) + jmodel_auto = thunder.jit( + model, + autotune_type="runtime", + autotune_enable_te=True, + autotune_nv_enable_options=True, + model_name="LLaMAMLP", + autotune_save_configuration=True, + ) print("deviation def:", (jmodel_def(x) - model(x)).abs().max().item()) print("deviation auto:", (jmodel_auto(x) - model(x)).abs().max().item()) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 74024f44e7..1a40cd746a 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -528,6 +528,10 @@ def compute_time_cost_ms(fn: Callable, repr: str, iters: int, *args) -> tuple[fl if snapshot: memory_snapshot(fn, args, snapshot_name) + # Save output + new_args = clone_args_if_needed(args) + out = fn(*new_args) + # Benchmark stream = torch.cuda.current_stream() max_allocated_bytes = 0 @@ -708,8 +712,13 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa if nsight: t, m, answer = compute_time_cost_nsight(executable, iters, *input_args) else: - t, m, answer = compute_time_cost_ms_torchtimer(executable, executable_str, *input_args) - # t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + # By default torch.utils.benchmark.Timer is employed for measurement but if TE FP8 is being used we have to used our custom measurer. + # https://github.com/mattteochen/lightning-thunder/blob/b728ab6416aca9a6fd621101a4fc68842b3ed60e/thunder/backend_optimizer/utils.py#L459 + if is_te_used(trace): + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + else: + t, m, answer = compute_time_cost_ms_torchtimer(executable, executable_str, *input_args) + # t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) except Exception: import traceback @@ -720,7 +729,6 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa # Restore the autocast value to not mess up the input trace if te_used: trace._include_te_fp8_autocast = cached_te_fp8_autocast_value - return t, m, answer @@ -1037,7 +1045,8 @@ def _number_hash(t: NumberProxy) -> str: return "{" + str(t.value) + "}" def _any_proxy_hash(p: AnyProxy) -> str: - return "{" + p.__repr__() + "}" + # We are not using class' __repr__ as it might contain memory addresses and those could change during different iterations + return "{AnyProxy}" def _sequence_hash(s: Sequence | None) -> str: if s is None: From 7f64dbffb121c26c7af1f0096bd68d89255e4ed7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 15:26:10 +0300 Subject: [PATCH 162/171] Formatter --- thunder/backend_optimizer/utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 1a40cd746a..b05d4f424c 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -479,7 +479,7 @@ def warm_up(fn: Callable, args: Sequence): new_args = clone_args_if_needed(args) fn(*new_args) - def memory_snapshot(fn: Callable, args: Sequence, file_name:str): + def memory_snapshot(fn: Callable, args: Sequence, file_name: str): new_args = clone_args_if_needed(args) torch.cuda.empty_cache() torch.cuda.synchronize() @@ -584,10 +584,7 @@ def compute_time_cost_ms_torchtimer(fn: Callable, repr: str, *args) -> tuple[flo stmt=""" fn(*new_args) """, - globals={ - "fn": fn, - "new_args": new_args - }, + globals={"fn": fn, "new_args": new_args}, ) t = t.blocked_autorange(min_run_time=1) return t.median, max_allocated_bytes, out @@ -695,7 +692,7 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa # Can not parse input args (usually due to OOM errors in upstream calls) if input_args is None: - return float('inf'), float('inf'), None + return float("inf"), float("inf"), None # Obtain the python executable string executable_str = trace.python() From 350c4516a941cc1fd4aa3cde095d97f393444e72 Mon Sep 17 00:00:00 2001 From: matteochen Date: Thu, 19 Sep 2024 15:41:56 +0300 Subject: [PATCH 163/171] Restored manual benchmark configuration / added env var for test runner --- examples/autotuner/litGPT.py | 4 +++- thunder/backend_optimizer/utils.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/autotuner/litGPT.py b/examples/autotuner/litGPT.py index 0414c37877..260ba925f1 100644 --- a/examples/autotuner/litGPT.py +++ b/examples/autotuner/litGPT.py @@ -9,10 +9,12 @@ import torch import time from pprint import pprint +import os torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn +os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" class LitGPTModelThunderConfig: def __init__( @@ -40,7 +42,7 @@ def __init__( LitGPTModelThunderConfig( 1, "runtime", - 1, + 2, executors=[ "cudnn", "sdpa", diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index b05d4f424c..7c2f2c8279 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -709,13 +709,14 @@ def scaled_dot_product_attention_backward(query, key, value, dropout_p, is_causa if nsight: t, m, answer = compute_time_cost_nsight(executable, iters, *input_args) else: - # By default torch.utils.benchmark.Timer is employed for measurement but if TE FP8 is being used we have to used our custom measurer. - # https://github.com/mattteochen/lightning-thunder/blob/b728ab6416aca9a6fd621101a4fc68842b3ed60e/thunder/backend_optimizer/utils.py#L459 - if is_te_used(trace): - t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) - else: - t, m, answer = compute_time_cost_ms_torchtimer(executable, executable_str, *input_args) - # t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + # # By default torch.utils.benchmark.Timer is employed for measurement but if TE FP8 is being used we have to used our custom measurer. + # # https://github.com/mattteochen/lightning-thunder/blob/b728ab6416aca9a6fd621101a4fc68842b3ed60e/thunder/backend_optimizer/utils.py#L459 + # if is_te_used(trace): + # t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) + # else: + # t, m, answer = compute_time_cost_ms_torchtimer(executable, executable_str, *input_args) + + t, m, answer = compute_time_cost_ms(executable, executable_str, iters, *input_args) except Exception: import traceback From 580050ae3f22c42f295236ef804a9a27fcd9e382 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 10:19:58 +0300 Subject: [PATCH 164/171] Disabled flag --- examples/autotuner/litGPT.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/autotuner/litGPT.py b/examples/autotuner/litGPT.py index 260ba925f1..714b84a5fc 100644 --- a/examples/autotuner/litGPT.py +++ b/examples/autotuner/litGPT.py @@ -9,12 +9,12 @@ import torch import time from pprint import pprint -import os torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +# import os +# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" class LitGPTModelThunderConfig: def __init__( From 8bbe1ec836dab8a2e2ddd9af80934a53b905a8cd Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 10:36:15 +0300 Subject: [PATCH 165/171] Updated doc --- thunder/backend_optimizer/optimizer.py | 6 ++++-- thunder/backend_optimizer/utils.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 5162b3200d..e066c91923 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -921,6 +921,7 @@ def measure_and_update_result(): n_missing_bsyms = len(group) - start_idx # Tune a single fusion group. + # NOTE: currently this is disabled for backward traces for i in range(0, n_missing_bsyms, n_missing_bsyms - 1 if self.trace_type == TraceType.BW else 1): if ex.name == "torchcompile": import torch @@ -938,6 +939,7 @@ def measure_and_update_result(): group[k], [dict_time_strat, dict_mem_strat], # In order to benchmark the fusion placecement, we can use any executor for the excluded bsym from the fusion region + # TODO: consider tuning the single trace regions removed from the fusion one get_first_available_operator_executor( bsym=group[k], executors=self.executors, @@ -1032,7 +1034,7 @@ def measure_and_update_result(): FusionExecutorsPlacementCtx(placement=executors_mem, compile_options=executor_compile_option) ) - # If any compile options will be used we will need to have duplicated executors inside the executors list to maintain the matching. + # If any compile options is used we will need to have duplicated executors inside the executors list to maintain the matching. self.fusion_executors_saved_for_later = [] ex: FusionExecutor for ex in self.fusion_executors: @@ -1048,7 +1050,7 @@ def measure_and_update_result(): ) self.fusion_executors_saved_for_later.append(ex) - # Always search with option disabled -> standard flow + # Always search with option disabled (standard flow) _search(ex) # Currently we are enabling one compile option at the time as testing all the permutations might need too much time. diff --git a/thunder/backend_optimizer/utils.py b/thunder/backend_optimizer/utils.py index 7c2f2c8279..4a163aa588 100644 --- a/thunder/backend_optimizer/utils.py +++ b/thunder/backend_optimizer/utils.py @@ -159,7 +159,7 @@ def get_not_used_intermediate_outsputs(trace_in: TraceCtx) -> list[Proxy]: This can be usefull if we want to force a specific TensorProxy to be returned in a modfied trace to avoid the dce. Args: - in_trace: A generic trace. + trace_in: A generic trace. """ def is_in_sequence(seq: Sequence[Any], t: Proxy): From df0469e06be15ffdc3c26b5e1a7c97e7971cbd7d Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 12:10:17 +0300 Subject: [PATCH 166/171] Autotuner for jit with no autograd --- thunder/__init__.py | 29 ++++++++++++++++---- thunder/backend_optimizer/optimizer.py | 32 +++++++++++++++++----- thunder/executors/passes.py | 37 +++++++++++++++++--------- thunder/tests/test_autotuner.py | 27 +++++++++++++++++++ 4 files changed, 102 insertions(+), 23 deletions(-) diff --git a/thunder/__init__.py b/thunder/__init__.py index 0505555ca8..86962aaba1 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -686,8 +686,10 @@ def get_computation_and_inputs(*args, **kwargs): if backward_trc is None: from thunder.executors.passes import transform_for_execution as transform_for_execution_pass + from thunder.executors.passes import autotune_transform_for_execution from thunder.executors.passes import _transform_for_operator_executor_execution from thunder.distributed.utils import maybe_sort_waits + from thunder.backend_optimizer.optimizer import BackendOptimizer, TraceType tmp_comp_trc = _transform_for_operator_executor_execution(computation_trc, cd.executors_list) is_transformed, tmp_comp_trc = maybe_sort_waits(tmp_comp_trc) @@ -695,11 +697,28 @@ def get_computation_and_inputs(*args, **kwargs): computation_trc = tmp_comp_trc computation_traces.append(computation_trc) - extraces = transform_for_execution( - computation_trc, - executors_list=cd.executors_list, - use_del_last_used=False, - ) + autotune = cd.compile_options.get('autotune_type', None) + if autotune is None: + extraces = transform_for_execution( + computation_trc, + executors_list=cd.executors_list, + use_del_last_used=False, + ) + else: + optimizer_ctx = BackendOptimizer( + priority_executors=cd.executors_list, + apply_bucketing_bw_trace=False, + produce_log=False, + optimizer_type=autotune, + compile_data=cd, + ) + extrace = autotune_transform_for_execution( + optimizer_context=optimizer_ctx, + trace=computation_trc, + trace_type=TraceType.FW, + is_computational=True + ) + extraces = [extrace] computation_traces.extend(extraces) computation_trc = computation_traces[-1] diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index e066c91923..caf6d84463 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -273,6 +273,7 @@ class PlacerBase: optimizer_type (OptimizerType): The optimization target. active_fw_trace_ctx (tuple): An active forward trace set to optimize backward. cached_fw_traces (list): Cached optimized forward traces. + best_comp_trace (TraceCtx): The optimized computational trace. cached_computational_trace (TraceCtx): Original computational trace cached_computational_backward_trace (TraceCtx): Original computational backward trace bw_trace_candidates (TraceCandidate): An instance of trace candidates. @@ -310,6 +311,7 @@ def __init__( self.active_fw_trace_ctx: tuple[TraceCtx | None, FusionExecutorsPlacementCtx | None] = None, None self.cached_fw_traces: list[TraceCandidate] = [] + self.best_comp_trace: TraceCtx = TraceCtx() self.cached_computational_trace: TraceCtx = TraceCtx() self.cached_computational_backward_trace: TraceCtx = TraceCtx() self.bw_trace_candidates: TraceCandidates = TraceCandidates() @@ -339,9 +341,12 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce=True """ pass - def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + def get_optimal_fw_traces(self, is_computational=False) -> Sequence[TraceCtx] | TraceCtx: """ Retrive the optimal forward traces that the object has tuned. + + Args: + is_computational: The requested forward trace is a computational trace (autograd is disabled). """ return [] @@ -496,6 +501,8 @@ def _filter_candidates(self): # Number of fw traces to cached are: #fusion_executors * 2 def fw_benchmark(): # The optimizer builds the results in order following the self.fusion_executors list order + best_time = BenchmarkResult() + best_mem = BenchmarkResult() pair_time: dict pair_mem: dict for pair_time, pair_mem in zip( @@ -508,12 +515,15 @@ def fw_benchmark(): trc_time, placement_ctx_time = list(pair_time.values())[0] trc_mem, placement_ctx_mem = list(pair_mem.values())[0] label = list(pair_time.keys())[0] - # TODO (matteochen): remove the benchmark here as will done later on the bw pass c, m, _ = benchmark_trace(trc_time, self.benchmark_iters) + if c < best_time.runtime: + best_time = BenchmarkResult(time=c, trace=trc_time) self.debug_msg += ( f"Trace name = [{label}] - Target: TIME - Time = {c} ms - Mem = {m / (2**30)} GB\n{trc_time}\n\n" ) c, m, _ = benchmark_trace(trc_mem, self.benchmark_iters) + if m < best_mem.memory: + best_mem = BenchmarkResult(memory=m, trace=trc_mem) self.debug_msg += ( f"Trace name = [{label}] - Target: MEM - Mem = {m / (2**30)} GB - Time = {c} ms\n{trc_mem}\n\n" ) @@ -531,6 +541,10 @@ def fw_benchmark(): else label, ) ) + # Assign best computational trace + self.best_comp_trace = best_time.trace if self.optimizer_type == OptimizerType.RUNTIME else best_mem.trace + + # Cache the original fw trace self.cached_computational_trace = self.trace def bw_benchmark(): @@ -587,6 +601,7 @@ def bw_benchmark(): ) ) + # Cache original backward trace self.cached_computational_backward_trace = self.trace match self.trace_type: @@ -1070,10 +1085,12 @@ def measure_and_update_result(): ################################################## Public methods ################################################## """ - def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + def get_optimal_fw_traces(self, is_computational=False) -> Sequence[TraceCtx] | TraceCtx: if not self.cached_fw_traces: raise AssertionError("Failed to obtain optimal fw traces") - return [candidate.trace for candidate in self.cached_fw_traces] + if not is_computational: + return [candidate.trace for candidate in self.cached_fw_traces] + return self.best_comp_trace def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: restore_file = self.compile_data.compile_options.get("autotune_restore_configuration", "") @@ -1351,11 +1368,14 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce=True """ self.optimizer.attach_trace(trace=trace, trace_type=trace_type) - def get_optimal_fw_traces(self) -> Sequence[TraceCtx]: + def get_optimal_fw_traces(self, is_computational=False) -> Sequence[TraceCtx] | TraceCtx: """ Retrive the optimal forward traces that the object has tuned. + + Args: + is_computational: The requested forward trace is a computational trace (autograd is disabled). """ - return self.optimizer.get_optimal_fw_traces() + return self.optimizer.get_optimal_fw_traces(is_computational) def get_optimal_fw_bw_traces(self) -> tuple[TraceCtx, TraceCtx]: """ diff --git a/thunder/executors/passes.py b/thunder/executors/passes.py index 5239aa1f41..20633f9928 100644 --- a/thunder/executors/passes.py +++ b/thunder/executors/passes.py @@ -141,8 +141,8 @@ def visit_(bsym: BoundSymbol) -> transforms.VISIT_TYPE: # Autotuned transform_for_execution version def autotune_transform_for_execution( - *, optimizer_context: BackendOptimizer, trace: TraceCtx, trace_type: TraceType -) -> tuple[TraceCtx, TraceCtx] | None: + *, optimizer_context: BackendOptimizer, trace: TraceCtx, trace_type: TraceType, is_computational: bool = False +) -> tuple[TraceCtx, TraceCtx] | TraceCtx | None: import torch start_time_ns = time.perf_counter_ns() @@ -161,10 +161,16 @@ def autotune_transform_for_execution( optimizer_context.log_file_name = f"autotune_transform_for_execution_{sig_name}.log" # Forward traces are cached inside the context optimizer_context.optimize() + + # Retrive the optimized traces. If backward trace is requested then the forward trace will be given only together with the backward one. + # This is because the optimal forward does not always lead to an optimal backward. + # If this is a computational trace (no autograd) then the forward (computational) trace will be ready and returned. match trace_type: case TraceType.FW: - # Nothing more left - pass + if not is_computational: + pass + else: + fw_trace: TraceCtx = optimizer_context.get_optimal_fw_traces(is_computational) # When optimizing the backward pass, the optimizer will return the best fw and bw traces based on the requested autotune_type, no need to choose the fw pass manually case TraceType.BW: fw_extrace, bw_extrace = optimizer_context.get_optimal_fw_bw_traces() @@ -176,14 +182,21 @@ def autotune_transform_for_execution( # Assign the trace provenance match trace_type: case TraceType.FW: - cd = get_compile_data() - if not cd or not cd.compile_options.get('autotune_restore_configuration', ""): - fw_traces = optimizer_context.get_optimal_fw_traces() - for trc in fw_traces: - trc.set_provenance( - TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") - ) - return None + if not is_computational: + cd = get_compile_data() + # Only for fresh tuning + if not cd or not cd.compile_options.get('autotune_restore_configuration', ""): + # We are assigning the provenance to all the possible candidates as at this stage we + # don't know which trace will be returned at the end of the optimization + fw_traces: list = optimizer_context.get_optimal_fw_traces() + for trc in fw_traces: + trc.set_provenance( + TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") + ) + return None + else: + fw_trace.set_provenance(TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)")) + return fw_trace case TraceType.BW: bw_extrace.set_provenance( TraceProvenance(f"Autotuned transform for execution (took {elapsed_time_millis} milliseconds)") diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 83378672b7..6d26df88ab 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -685,3 +685,30 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: thunder.last_traces(jitted)[-1].bound_symbols, thunder.last_traces(jitted_recovered)[-1].bound_symbols ): assert bsym_a.sym.executor == bsym_b.sym.executor + + +def test_no_autograd_trace_autotuning(): + def _fn(a, b): + t0 = a + b + t1 = a + t0 + t2 = t1 * t1 + t3 = b - t2 + return b @ t3 + + executors = ['torch', 'torchcompile'] + jfn_def = thunder.jit(_fn, executors=executors) + jfn_auto = thunder.jit(_fn, autotune_type='runtime', disable_torch_autograd=True, exeuctors=executors) + a = torch.randn(4,4) + b = torch.randn(4,4) + + y_def = jfn_def(a, b) + y_auto = jfn_auto(a, b) + + applied = set() + trace = thunder.last_traces(jfn_auto)[-1] + for b in trace.bound_symbols: + if b.sym.executor is not None: + applied.add(b.sym.executor.name) + + assert (applied == set(['torch']) or applied == set(['torchcompile', 'torch'])) + torch.testing.assert_close(y_def, y_auto) From 7f57584e84b29c8b28962800674d845508c0bd99 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 13:23:52 +0300 Subject: [PATCH 167/171] Added CUDA barrier for unit test --- thunder/tests/test_autotuner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/thunder/tests/test_autotuner.py b/thunder/tests/test_autotuner.py index 6d26df88ab..fa60343037 100644 --- a/thunder/tests/test_autotuner.py +++ b/thunder/tests/test_autotuner.py @@ -687,6 +687,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: assert bsym_a.sym.executor == bsym_b.sym.executor +@requiresCUDA +# Currently inside the autotuner flow nvfuser will be imported which will lead to import errors def test_no_autograd_trace_autotuning(): def _fn(a, b): t0 = a + b From 82017d267c93f3565d6e86de16b0d28d5d88fe02 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 14:03:24 +0300 Subject: [PATCH 168/171] Integrated autotuner in benchmark script --- thunder/benchmarks/benchmark_litgpt.py | 32 +++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index f63a66f5ce..fab3ff1527 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -238,6 +238,8 @@ def __init__( use_torchao_fp8_linear: bool = False, use_torchao_fp8_allgather: bool = False, use_torchao_fp8_precompute_scale_for_fsdp: bool = False, + autotune: str = "", + save_autotune_cfg: bool = False ): seed = 1337 torch.manual_seed(seed) @@ -358,6 +360,13 @@ def __init__( self.profiler_start = profiler_start self.profiler_stop = profiler_stop + # Autotuner + supported_autotuning = set(['runtime', 'memory', '']) + if autotune not in supported_autotuning: + raise AssertionError(f"Autotuning configuration not supported. Available ones are: {[a for a in supported_autotuning if a]}") + self.autotune_type = autotune + self.save_autotune_cfg = save_autotune_cfg + if n_layers is not None: self.config.n_layer = n_layers @@ -569,6 +578,16 @@ def setup_compile(self, model): executors.insert(0, transformer_engine_ex) + if "fa3" in self.compile: + from thunder.executors.fa3ex import fa3_ex + + executors.insert(0, fa3_ex) + + if "nvmath" in self.compile: + from thunder.executors.nvmathex import nvmath_ex + + executors.insert(0, nvmath_ex) + if "dynamo" in self.compile: if self.distributed_mode == "fsdp2": print("Resetting cache size for when fsdp2 and using thunder as backend torch.compile") @@ -595,7 +614,18 @@ def setup_compile(self, model): # so we are using the lower level torch._dynamo.optimize function model = torch._dynamo.optimize(backend=backend)(model) else: - model = thunder.jit(model, executors=executors) + if self.autotune_type: + # nvFuser compile options to be enabled if wanted with: autotune_nv_enable_options=True + model = thunder.jit( + model, + autotune_type=self.autotune_type, + executors=executors, + autotune_optimize_common_blocks=True, + autotune_save_configuration=self.save_autotune_cfg, + autotune_enable_te="transformerengine" in self.compile + ) + else: + model = thunder.jit(model, executors=executors) elif self.compile != "eager": raise ValueError(f"Invalid compile option: {self.compile}") From 26044a3e2de6ef53900b0305c4d5cd8e8b6ca8b7 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 14:13:25 +0300 Subject: [PATCH 169/171] Added missing flag --- thunder/benchmarks/benchmark_litgpt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py index fab3ff1527..bc97c3f0c9 100644 --- a/thunder/benchmarks/benchmark_litgpt.py +++ b/thunder/benchmarks/benchmark_litgpt.py @@ -621,6 +621,7 @@ def setup_compile(self, model): autotune_type=self.autotune_type, executors=executors, autotune_optimize_common_blocks=True, + autotune_optimize_common_blocks_min_size=20, # This is quite low for a traced transformer block but will do the job autotune_save_configuration=self.save_autotune_cfg, autotune_enable_te="transformerengine" in self.compile ) From 0a171289a3d2e88385c7b2e2c7558910f6b611b8 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 15:05:50 +0300 Subject: [PATCH 170/171] Fixed comments --- thunder/backend_optimizer/optimizer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index caf6d84463..7a21d414cb 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -889,7 +889,6 @@ def match_bsym_executor(bsym_in: BoundSymbol, dicts: list[dict], ex_in: Executor best_res_time = BenchmarkResult() best_res_mem = BenchmarkResult() - # TODO (matteochen): Aggregate them best_placement_time = None best_keys_time = None best_placement_mem = None @@ -1050,6 +1049,7 @@ def measure_and_update_result(): ) # If any compile options is used we will need to have duplicated executors inside the executors list to maintain the matching. + # TODO: integrate torchcompile_cat alongside with nvFuser. This should speed up the autotuner too. self.fusion_executors_saved_for_later = [] ex: FusionExecutor for ex in self.fusion_executors: @@ -1122,7 +1122,6 @@ def attach_trace(self, *, trace: TraceCtx, trace_type: TraceType, apply_dce: boo match self.trace_type: case TraceType.FW: logger.info(f"New forward trace to optimize (strat = {self.optimizer_type})") - # TODO (matteochen): support bw trace optimization even though with no fw traces cached (computational trace?) case TraceType.BW: if not self.compile_data.compile_options.get("autotune_restore_configuration", ""): if not self.cached_fw_traces: From afb563776f324102e5b2d405eb3943e1637c1e72 Mon Sep 17 00:00:00 2001 From: matteochen Date: Fri, 20 Sep 2024 15:11:33 +0300 Subject: [PATCH 171/171] Removed comment --- thunder/backend_optimizer/optimizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/thunder/backend_optimizer/optimizer.py b/thunder/backend_optimizer/optimizer.py index 7a21d414cb..37a5076e20 100644 --- a/thunder/backend_optimizer/optimizer.py +++ b/thunder/backend_optimizer/optimizer.py @@ -987,6 +987,7 @@ def measure_and_update_result(): executors_mem = [] for bsym in self.trace.bound_symbols: if bsym.sym.id == PrimIDs.RETURN: + # TODO (matteochen): Aggregate them if "return" not in dict_time_strat or "return" not in dict_mem_strat: raise AssertionError(f"Expected key return in mapping {dict_time_strat} and {dict_mem_strat}") executors_time.append(dict_time_strat["return"])